{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(3, 192, 192, 3)\n",
      "0 255\n",
      "(3, 6, 192, 192)\n",
      "0.0 1.0\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAd8AAAKvCAYAAAArysUEAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAIABJREFUeJzt3X+sHOV97/HPp3bDH5QrIGwty+AakJMoRO0JHLlBCRGEJjEIxSGRqK0qcVJ0D0hw1bSVekmRLqgSUpWGoFv1huQgLJyrxEDrkKDKbcNFuSGtoHBMfB1DINjECB859glEAYWIxPb3/nFmk+Gw692zM/vMj32/pNXuPjOz851zzvjj55nZGUeEAABAOr9VdQEAAEwawhcAgMQIXwAAEiN8AQBIjPAFACAxwhcAgMTGFr62N9p+1vZ+2zeNaz0AADSNx/E9X9srJP1Q0gclHZL0hKQtEfF06SsDAKBhxtXz3SBpf0Q8HxG/lHSvpE1jWhcAAI2yckyfu0bSi7n3hyT9Yb+ZbXOZLUyyn0REp+oiynLWWWfFunXrqi4DqMTu3buH2p/HFb4D2Z6RNFPV+oEaeaHqAorK789r167V3NxcxRUB1bA91P48rmHneUnn5N6fnbX9WkTMRsR0REyPqQYAieT3506nNZ14YGzGFb5PSFpv+1zbb5G0WdKDY1oXAACNMpZh54g4ZvtGSf8maYWkbRHx1DjWBQBA04ztmG9E7JK0a1yfDwBAU3GFKwAAEiN8AQBIjPAFACAxwhcAgMQIXwAAEiN8AQBIjPAFACAxwhcAgMQIXwAAEiN8AQBIjPAFACAxwhcAgMQIXzRGRFRdAoCS+MtXVV1CpQhfNAoBDLTHJAcw4YvGIYCB9pjUACZ80UgEMNAekxjAI4ev7XNsf9v207afsv1nWfuttudt78keV5ZXLvAbBDDQHpMWwEV6vsck/WVEvFPSeyTdYPud2bQ7ImIqe+wqXCXQBwEMtMckBfDI4RsRhyPiyez1q5J+IGlNWYUBwyKAgfaYlAAu5Ziv7XWS3i3pP7OmG23vtb3N9hllrAM4GQIYaI9JCODC4Wv7dyTtlPSZiHhF0p2Szpc0JemwpNv7LDdje872XNEaAIkArlJ+f15YWKi6HLRA2wO4UPja/m0tBu9XI+LrkhQRRyLieESckHSXpA29lo2I2YiYjojpIjUAeQRwNfL7c6fTqboctESbA7jI2c6WdLekH0TEF3Ltq3OzXS1p3+jlActHAAPt0dYALtLzfa+kT0j6wJKvFX3O9vdt75V0maQ/L6NQYDkIYKA92hjAK0ddMCL+XZJ7TOKrRaiFiNDiAA2ApvOXr1Jc989Vl1EarnCFVqMHDLRHm3rAhC9ajwAG2qMtAUz4YiIQwEB7tCGARz7mC6TG8VugPdp0/HYU9HwBAEiM8AUAIDHCFwCAxAhfAAASI3wBAEiM8AUAIDHCFwCAxAhfAAASI3wBAEiM8AUAIDHCFwCAxLi2MwDgpC4+/7bSPuvRAzeX9llNVrjna/ug7e/b3mN7Lms70/ZDtp/Lns8oXiqapHsXobKeAaBNyhp2viwipiJiOnt/k6SHI2K9pIez95gg3TsQlfUMAG0yrmO+myRtz15vl/TRMa0HNUXPFwD6KyN8Q9K3bO+2PZO1rYqIw9nrH0taVcJ60CD0fAGgvzJOuHpfRMzb/l1JD9l+Jj8xIsL2m7ovWVDPLG1HO0SEbJf2jHrL789r166tuBqg/gr3fCNiPns+KukBSRskHbG9WpKy56M9lpuNiOnccWK0CD3fyZLfnzudTtXlALVXKHxtn2r7tO5rSR+StE/Sg5K2ZrNtlfTNIutB83DMFwD6KzrsvErSA1nvZKWkr0XEv9p+QtL9tq+V9IKkawquBw1DzxcA+isUvhHxvKQ/6NH+kqTLi3w2mo1jvgDQH5eXxFjQ8wWA/ghfjAXHfAGgP8IX
Y0HPFwD6I3wxFvR8AaA/whdjQc8XAPojfDEW9HwBoD/CF2NBzxcA+iN8MRb0fAGgP8IXY0HPFwD6K+OuRpUatWfEP+pA/Tz+3st+/frJiz/Wc55e7bMfXzO2miA9euDmqktoncaH76i6oU0IA/XSL3RPZmbnvCRCGM0x8cPOHFME6mOU4M3rhjBQd40O37KCkwAGqldWcBLAaILGDjsvJzDzQ8v9luPuOUB1ThaYFz769Te8v/7z/23gcjM75xmCRq01sudbpKdqu2/I0gMG0usXoBc++vU3Be9Ssx9f0zdk6QGjzhoXvmUFJAEMVO9kwbscBDCaZuTwtf1223tyj1dsf8b2rbbnc+1XlllwmRhmBupnucHbxTAzmmTk8I2IZyNiKiKmJF0k6TVJD2ST7+hOi4hdZRSarbOsj/q1XgFM77d6EcHvoeV69UpHDd6uXgFM77d62797sbZ/9+Kqy6iVsk64ulzSgYh4IXVvkt4r0B4b/uPbVZcAJFHWMd/Nknbk3t9oe6/tbbbPKGkdb1JW8BLgQPXKGjZm+BlNUDh8bb9F0kck/WPWdKek8yVNSTos6fY+y83YnrM9V7QGANXK788LCwtVlwPUXhk93yskPRkRRyQpIo5ExPGIOCHpLkkbei0UEbMRMR0R0yXUAKBC+f250+lUXQ5Qe2WE7xblhpxtr85Nu1rSvhLWAQBAaxQ64cr2qZI+KOm6XPPnbE9JCkkHl0wDAGDiFQrfiPi5pLcuaftEoYoAAGi5xl3hCgCApmvsjRWk8m6GwMUc0lrOz3uYefmqWDuUdTMELqqR1nIunjHMvFsvebRIOY1BzxcAgMRch16f7aGL6FVvkZ5P2Z+HcnR/LxPyu9jdpq/cTU9Px9zccF/f79VLLdL7LfvzUI5uj3cSerW2h9qfW9HzHfU/EHX4jweANxp12JjhZjRJ48K3rFsB9pt/QnpaQC2UdSvAfvPT60VdNS58pZMH8KAQPtk8BC+Q3skCeFAIn2weghd11tiznW33DdFRhpMJXqA6sx9f0zdERxlOJnhRd43s+XZxVyOgPbirESZJo8NXKh6cBC9QH0WDk+BFUzR22DmvG6DLGW4mdIF66gbocoabCV00TSvCt4tABdqDQEWbtSp80R78Rwpoj0m4uMZytS58Tzb0zD/oQLO8uqf/aSmnTZ1IWAlQrtaE7zDHeyfskoVAY50sdJfOQwijiRp/trNU3tWtAFRvmOAtMj9QB0P91dreZvuo7X25tjNtP2T7uez5jKzdtv/e9n7be21fOK7iAQBoomH/y3iPpI1L2m6S9HBErJf0cPZekq6QtD57zEi6s3iZ/XFTBaA9Ru3F0vtF0wz1FxsRj0h6eUnzJknbs9fbJX001/6VWPSYpNNtry6jWAAA2qDIfxdXRcTh7PWPJa3KXq+R9GJuvkNZGwAAUEknXMXiGO6yxnFtz9iesz3cXbcB1FZ+f15YWKi6HKD2ioTvke5wcvZ8NGufl3RObr6zs7Y3iIjZiJiOiOkCNQCogfz+3Ol0qi4HqL0i4fugpK3Z662Svplr/2R21vN7JP0sNzwNAMDEG+oiG7Z3SLpU0lm2D0m6RdLfSrrf9rWSXpB0TTb7LklXStov6TVJny65ZgAAGm2o8I2ILX0mXd5j3pB0Q5GilsP2SF8b4ipXQP2cNnVipK8NcZUrNA1fjgMAILFWhO9ye7H0eoH6Wm4vll4vmqg1N1boBip3NQKarxuo3NUIbdWa8O0iYIH2IGDRVq0YdgYAoEkIXwAAEiN8AQBIjPAFACAxwhcAgMQIXwAAEiN8AQBIjPAFACAxwhcAgMQIXwAAEiN8AQBIjPAFACCxgeFre5vto7b35dr+zvYztvfafsD26Vn7Otu/sL0ne3xpnMUDANBEw/R875G0cUnbQ5LeFRG/L+mHkj6bm3YgIqayx/XllAkAQHsMDN+IeETSy0vavhURx7K3j0k6ewy1AQDQSmUc8/1TSf+Se3+u7e/Z
/o7tS0r4fAAAWmVlkYVt3yzpmKSvZk2HJa2NiJdsXyTpG7YviIhXeiw7I2mmyPoB1EN+f167dm3F1QD1N3LP1/anJF0l6U8iIiQpIl6PiJey17slHZD0tl7LR8RsRExHxPSoNQCoh/z+3Ol0qi4HqL2Rwtf2Rkl/JekjEfFarr1je0X2+jxJ6yU9X0ahAAC0xcBhZ9s7JF0q6SzbhyTdosWzm0+R9JBtSXosO7P5/ZL+xvavJJ2QdH1EvNzzgwEAmFADwzcitvRovrvPvDsl7SxaFAAAbcYVrgAASIzwBQAgMcIXAIDECF8AABIjfAEASIzwBQAgMcIXAIDECF8AABIjfAEASIzwBQAgMcIXAIDECF8AABIjfAEASIzwBQAgMcIXAIDECF8AABIbGL62t9k+antfru1W2/O292SPK3PTPmt7v+1nbX94XIUDANBUw/R875G0sUf7HRExlT12SZLtd0raLOmCbJkv2l5RVrEAALTBwPCNiEckvTzk522SdG9EvB4RP5K0X9KGAvUBANA6RY753mh7bzYsfUbWtkbSi7l5DmVtAAAgM2r43inpfElTkg5Lun25H2B7xvac7bkRawBQE/n9eWFhoepygNobKXwj4khEHI+IE5Lu0m+GluclnZOb9eysrddnzEbEdERMj1IDgPrI78+dTqfqcoDaGyl8ba/Ovb1aUvdM6AclbbZ9iu1zJa2X9HixEgEAaJeVg2awvUPSpZLOsn1I0i2SLrU9JSkkHZR0nSRFxFO275f0tKRjkm6IiOPjKR0AgGYaGL4RsaVH890nmf82SbcVKQoAgDbjClcAACRG+AIAkBjhCwBAYoQvAACJEb4AACRG+AIAkBjhCwBAYoQvAACJEb4AACRG+AIAkBjhCwBAYoQvAACJEb4AACRG+AIAkBjhCwBAYoQvAACJDQxf29tsH7W9L9d2n+092eOg7T1Z+zrbv8hN+9I4iwcAoIlWDjHPPZL+QdJXug0R8cfd17Zvl/Sz3PwHImKqrAIBAGibgeEbEY/YXtdrmm1LukbSB8otCwCA9ip6zPcSSUci4rlc27m2v2f7O7YvKfj5AAC0zjDDziezRdKO3PvDktZGxEu2L5L0DdsXRMQrSxe0PSNppuD6AdRAfn9eu3ZtxdUA9Tdyz9f2Skkfk3Rfty0iXo+Il7LXuyUdkPS2XstHxGxETEfE9Kg1AKiH/P7c6XSqLgeovSLDzn8k6ZmIONRtsN2xvSJ7fZ6k9ZKeL1YiAADtMsxXjXZIelTS220fsn1tNmmz3jjkLEnvl7Q3++rRP0m6PiJeLrNgAACabpiznbf0af9Uj7adknYWLwsAgPbiClcAACRG+AIAkBjhCwBAYoQvAACJEb4AACRG+AIAkBjhCwBAYoQvAACJEb4AACRG+AIAkBjhCwBAYo6IqmuQ7QVJP5f0k6prKcFZYjvqpAnb8XsR0Zr78Nl+VdKzVddRgib87QyD7UhrqP25FuErSbbn2nBvX7ajXtqyHU3Slp8521EvbdmOLoadAQBIjPAFACCxOoXvbNUFlITtqJe2bEeTtOVnznbUS1u2Q1KNjvkCADAp6tTzBQBgIhC+AAAkRvgCAJAY4QsAQGKELwAAiRG+AAAkRvgCAJAY4QsAQGKELwAAiRG+AAAkRvgCAJAY4QsAQGKELwAAiRG+AAAkRvgCAJAY4QsAQGKELwAAiRG+AAAkRvgCAJAY4QsAQGKELwAAiRG+AAAkRvgCAJAY4QsAQGKELwAAiRG+AAAkRvgCAJAY4QsAQGKELwAAiY0tfG1vtP2s7f22bxrXegAAaBpHRPkfaq+Q9ENJH5R0SNITkrZExNOlrwwAgIYZV893g6T9EfF8RPxS0r2SNo1pXQAANMq4wneNpBdz7w9lbQAATLyVVa3Y9oykmeztRVXVAdTATyKiU3URReT351NPPfWid7zjHRVXBFRj9+7dQ+3P4wrfeUnn5N6fnbX9WkTMSpqVJNvlH3gGmuOF
qgsoKr8/T09Px9zcXMUVAdWwPdT+PK5h5yckrbd9ru23SNos6cExrQsAgEYZS883Io7ZvlHSv0laIWlbRDw1jnUBANA0YzvmGxG7JO0a1+cDANBUXOEKAIDECF8AABIjfAEASIzwBQAgMcIXAIDECF8AABIjfAEASIzwBQAgMcIXAIDECF8AABIjfAEASIzwBQAgMcIXAIDECF8AABIjfAEASIzwBQAgsZHD1/Y5tr9t+2nbT9n+s6z9VtvztvdkjyvLKxcAgOZbWWDZY5L+MiKetH2apN22H8qm3RERny9eHgAA7TNy+EbEYUmHs9ev2v6BpDVlFQYAQFuVcszX9jpJ75b0n1nTjbb32t5m+4wy1gEAQFsUDl/bvyNpp6TPRMQrku6UdL6kKS32jG/vs9yM7Tnbc0VrAFCt/P68sLBQdTlA7RUKX9u/rcXg/WpEfF2SIuJIRByPiBOS7pK0odeyETEbEdMRMV2kBgDVy+/PnU6n6nKA2itytrMl3S3pBxHxhVz76txsV0vaN3p5AAC0T5Gznd8r6ROSvm97T9b215K22J6SFJIOSrquUIUAALRMkbOd/12Se0zaNXo5AAC0H1e4AgAgMcIXAIDECF8AABIjfAEASIzwBQAgMcIXAIDECF8AABIjfAEASIzwBQAgMcIXAIDECF8AABIjfAEASIzwBQAgMcIXAIDECF8AABIjfAEASGxl0Q+wfVDSq5KOSzoWEdO2z5R0n6R1kg5KuiYiflp0XQAAtEFZPd/LImIqIqaz9zdJejgi1kt6OHsPAAA0vmHnTZK2Z6+3S/romNYDAEDjlBG+IelbtnfbnsnaVkXE4ez1jyWtKmE9AAC0QuFjvpLeFxHztn9X0kO2n8lPjIiwHUsXyoJ6Zmk7qhex+OuyXXElaIr8/rx27dqKq0He9u9eLEnaesmjFVeCvMI934iYz56PSnpA0gZJR2yvlqTs+WiP5WYjYjp3nBhAQ+X3506nU3U5QO0VCl/bp9o+rfta0ock7ZP0oKSt2WxbJX2zyHoAAGiTosPOqyQ9kA1PrpT0tYj4V9tPSLrf9rWSXpB0TcH1AADQGoXCNyKel/QHPdpfknR5kc8GAKCtuMIVAACJEb4AACRG+AIAkFgZ3/NFw3S/x1t0Pr4HDFSv+z3eovPxPeC06PkCAJAYPd8JNKjHyhWugOYY1GPlClf1RM8XAIDECF8AABIjfAEASIzwBQAgMcIXAIDECF8AABIjfAEASIzv+eJN+H4v0B58v7ee6PkCAJAY4QsAQGIjDzvbfruk+3JN50n6H5JOl/RfJS1k7X8dEbtGrhAAgJYZOXwj4llJU5Jke4WkeUkPSPq0pDsi4vOlVAgAQMuUNex8uaQDEfFCSZ8HAEBrlRW+myXtyL2/0fZe29tsn1HSOgAAaIXC4Wv7LZI+Iukfs6Y7JZ2vxSHpw5Ju77PcjO0523NFawBQrfz+vLCwMHgBYMKV0fO9QtKTEXFEkiLiSEQcj4gTku6StKHXQhExGxHTETFdQg0AKpTfnzudTtXlALVXRvhuUW7I2fbq3LSrJe0rYR0AALRGoStc2T5V0gclXZdr/pztKUkh6eCSaQAATLxC4RsRP5f01iVtnyhUEQAALccVrgAASIzwBQAgMcIXAIDECF8AABLjfr6YKBEx9Lzc1xiot8ffe9nQ8274j2+PsZLlo+cLAEBi9HwxEZbT4wVQbzM75yVJF178sYHzXvjo18ddzkgI34IiQrYHPqMahC6GdfH5tw0976MHbh5jJeinG7rL8WQW0D2vc1whhp0L6gbroGekVzR4CW6gPkYJ3jKXLxs934Lo+dZPmaHZ/Sx+h0A1ygzN7mfNfnxNaZ85KsK3IHq+9TKu3ir/iQLSGxS8Jzue++RJjgfP7JyvPIAJ34Lo+TZD93fAUDLQfN3gfPwL/efpBvPJQrhKHPMtiJ5vffQL1rJ+BwQ3kE6/Xu9ye6z9esdVHwMmfAvq/oM86BnjNe7g
HbQeAOUpK3i76hjAhG9B9Hzrq+jPnt8dUB9Fj9FWfYx3qdYf8z1ZT6WMf1w55lu9Xr/jfj/z5f4uur/Dpevjd1qNV/f07y+cNnUiYSUYl1690X7BudxLRs5+fM2bPr+qk6+G6vna3mb7qO19ubYzbT9k+7ns+Yys3bb/3vZ+23ttXziu4gcZNERYxhAiPV8gjZMF7zDTgToZ9q/1Hkkbl7TdJOnhiFgv6eHsvSRdIWl99piRdGfxMpdv2GAt60IMHPOtj7L/w8N/oKo3bLASwO1Tdq+0LsPPQw07R8Qjttctad4k6dLs9XZJ/1fSf8/avxKLqfOY7dNtr46Iw2UUPIzlBl6RYUR6vsB4LTdQX93zWyMNQXPJSKRU5L+Jq3KB+mNJq7LXayS9mJvvUNYGAABU0tnOWS93Wd1N2zO252zPlVEDgOrk9+eFhYWqywFqr0j4HrG9WpKy56NZ+7ykc3LznZ21vUFEzEbEdERMF6gBQA3k9+dOp1N1OUDtFQnfByVtzV5vlfTNXPsns7Oe3yPpZymP9wIAUHdDnXBle4cWT646y/YhSbdI+ltJ99u+VtILkq7JZt8l6UpJ+yW9JunTJdcMAECjDXu285Y+ky7vMW9IuqFIUUARZV8Eg6+LAdUp+yIYVV/TuauVX4ob5SpGAOppuV8b4kpXaIJWhq80fKASvM3X63dYVm91OZeuxPgMG6gEb/P16uWW1VtdzqUrx6214SsN/keSf0Tbrayrl6EeBgUrwdtuRQO4LsPNXa0OX2kxYPs90B79fp+jBmiqWxRieU6bOtH3gfbo1xsdNUDLvkVhGVofvpgcZQUwwQtUr6wArmPwShNwS0FA+k2gnixAGWYGmqEbqCcL0LoNMy9F+KJVet1/N2/UgKXXC6TX6/67eaMGbNW9XonwRQt1g7LM+zUDqEY3KMvoydYhdLs45ovWKhqcBC9QH0WDs07BK9HzRcuN0gsmdIF6GqUXXLfQ7SJ8MREIVKA96hqoy8GwMwAAiRG+AAAkRvgCAJAY4QsAQGKELwAAiQ0MX9vbbB+1vS/X9ne2n7G91/YDtk/P2tfZ/oXtPdnjS+MsHgCAJhqm53uPpI1L2h6S9K6I+H1JP5T02dy0AxExlT2uL6dMAADaY2D4RsQjkl5e0vatiDiWvX1M0tljqA0AgFYq45jvn0r6l9z7c21/z/Z3bF9SwucDANAqha5wZftmScckfTVrOixpbUS8ZPsiSd+wfUFEvNJj2RlJM0XWD6Ae8vvz2rVrK64GqL+Re762PyXpKkl/EtmFcyPi9Yh4KXu9W9IBSW/rtXxEzEbEdERMj1oDgHrI78+dTqfqcoDaGyl8bW+U9FeSPhIRr+XaO7ZXZK/Pk7Re0vNlFAoAQFsMHHa2vUPSpZLOsn1I0i1aPLv5FEkPZResfyw7s/n9kv7G9q8knZB0fUS83PODAQCYUAPDNyK29Gi+u8+8OyXtLFpU1SKCu+AALeEvX6W47p+rLgN4A65w1cdy7v8KoN785auqLgF4A8L3JAhgoD0IYNQJ4TsAAQy0BwGMuiB8h0AAA+1BAKMOCN8hEcBAexDAqBrhuwwEMNAeBDCqRPguEwEMtAcBjKoQviMggIH2IIBRBcJ3RAQw0B4EMFIjfAsggIH2IICREuFbEAEMtAcBjFQI3xIQwEB7EMBIgfAtCQEMtAcBjHEbeFejScQdjYD24I5GqCN6vgAAJEb4AgCQ2MDwtb3N9lHb+3Jtt9qet70ne1yZm/ZZ2/ttP2v7w+MqHACAphqm53uPpI092u+IiKnssUuSbL9T0mZJF2TLfNH2irKKBQCgDQaGb0Q8IunlIT9vk6R7I+L1iPiRpP2SNhSoDwCA1ilyzPdG23uzYekzsrY1kl7MzXMoawMAAJlRw/dOSedLmpJ0WNLty/0A2zO252zPjVgDgJrI788LCwtVlwPU3kjhGxFHIuJ4RJyQdJd+M7Q8L+mc3KxnZ229PmM2IqYjYnqUGgDU
R35/7nQ6VZcD1N5I4Wt7de7t1ZK6Z0I/KGmz7VNsnytpvaTHi5UIAEC7DLzCle0dki6VdJbtQ5JukXSp7SlJIemgpOskKSKesn2/pKclHZN0Q0QcH0/pAAA008DwjYgtPZrvPsn8t0m6rUhRAAC0GVe4AgAgMcIXAIDECF8AABIjfAEASIzwBQAgMcIXAIDECF8AABIjfAEASIzwBQAgMcIXAIDECF8AABIjfAEASIzwBQAgMcIXAIDECF8AABIjfAEASGxg+NreZvuo7X25tvts78keB23vydrX2f5FbtqXxlk8AABNtHKIee6R9A+SvtJtiIg/7r62fbukn+XmPxARU2UVCABA2wwM34h4xPa6XtNsW9I1kj5QblkAALRX0WO+l0g6EhHP5drOtf0929+xfUnBzwcAoHWGGXY+mS2SduTeH5a0NiJesn2RpG/YviAiXlm6oO0ZSTMF1w+gBvL789q1ayuuBqi/kXu+tldK+pik+7ptEfF6RLyUvd4t6YCkt/VaPiJmI2I6IqZHrQFAPeT3506nU3U5QO0VGXb+I0nPRMShboPtju0V2evzJK2X9HyxEgEAaJdhvmq0Q9Kjkt5u+5Dta7NJm/XGIWdJer+kvdlXj/5J0vUR8XKZBQMA0HTDnO28pU/7p3q07ZS0s3hZAAC0F1e4AgAgMcIXAIDECF8AABIjfAEASIzwBQAgMcIXAIDECF8AABIjfAEASIzwBQAgMcIXAIDECF8AABIjfAEASMwRUXUNsr0g6eeSflJ1LSU4S2xHnTRhO34vIlpzE1zbr0p6tuo6StCEv51hsB1pDbU/1yJ8Jcn2XERMV11HUWxHvbRlO5qkLT9ztqNe2rIdXQw7AwCQGOELAEBidQrf2aoLKAnbUS9t2Y4macvPnO2ol7Zsh6QaHfMFAGBS1KnnCwDARCB8AQBIjPAFACAxwhcAgMQIXwAAEiN8AQBIjPAFACAxwhcAgMQIXwAAEiN8AQBIjPAFACAxwhcAgMQIXwAAEiN8AQBIjPAFACAxwhcAgMQIXwAAEiN8AQBIjPAFACAxwhcAgMQIXwAAEiN8AQBIjPAFACAxwhcAgMQIXwAAEiN8AQBIjPAFACAxwhcAgMQIXwAAEhtb+NreaPurKeZ9AAART0lEQVRZ2/tt3zSu9QAA0DSOiPI/1F4h6YeSPijpkKQnJG2JiKdLXxkAAA0zrp7vBkn7I+L5iPilpHslbRrTugAAaJSVY/rcNZJezL0/JOkP8zPYnpE0k729aEx1AE3wk4joVF1EEfn9+dRTT73oHe94R8UVAdXYvXv3UPvzuMJ3oIiYlTQrSbbLH/sGmuOFqgsoKr8/T09Px9zcXMUVAdWwPdT+PK5h53lJ5+Ten521AQAw8cYVvk9IWm/7XNtvkbRZ0oNjWhcAAI0ylmHniDhm+0ZJ/yZphaRtEfHUONYFAEDTjO2Yb0TskrRrXJ8PAEBTcYUrAAASI3wBAEiM8AUAIDHCFwCAxAhfAAASI3wBAEiM8AUAIDHCFwCAxAhfAAASI3wBAEiM8AUAIDHCFwCAxAhfAAASI3wBAEiM8AUAIDHCFwCAxEYOX9vn2P627adtP2X7z7L2W23P296TPa4sr1wAAJpvZYFlj0n6y4h40vZpknbbfiibdkdEfL54ecB4RUTfabYTVgKgqJmd832nzX58TcJKBhs5fCPisKTD2etXbf9AUr22DujjZKG7dB5CGKi3k4Xu0nnqEsJFer6/ZnudpHdL+k9J75V0o+1PSprTYu/4pz2WmZE0U8b6geVYGry9wjU/T0QQwAPk9+e1a9dWXA0mydLg7RWu+Xlmds7XIoALn3Bl+3ck7ZT0mYh4RdKdks6XNKXFnvHtvZaLiNmImI6I6aI1AKPqF6qE7fLk9+dOp1N1OZhQ/UK1DmG7VKHwtf3bWgzer0bE1yUpIo5ExPGIOCHpLkkbipcJlCPfox0UsPnpwwxTA0gr36MdFLD56cMMU49bkbOd
LeluST+IiC/k2lfnZrta0r7RywPGY9ieLT1goP6G7dnWqQdc5JjveyV9QtL3be/J2v5a0hbbU5JC0kFJ1xWqEACAlilytvO/S+rVLdg1ejkAALQfV7gCACAxwhcAgMQIX0ykYc9e5ixnoP6GPXu5Dmc5dxG+mCjL+frQcr6WBCC95Xx9aDlfS0qB8MVE6xfA9HiB5ukXwHXq8XaVcnlJoElsv+nykYPmB1BPsx9f86bLRw6avw4IX0ykbqByVyOg+bqBOhF3NQLagIAF2qNuAXsyHPMFACAxwhcAgMQIXwAAEiN8AQBIjPAFACAxwhcAgMQIXwAAEiN8AQBIrPBFNmwflPSqpOOSjkXEtO0zJd0naZ2kg5KuiYifFl0XAABtUFbP97KImIqI6ez9TZIejoj1kh7O3gMAAI1v2HmTpO3Z6+2SPjqm9QAA0DhlhG9I+pbt3bZnsrZVEXE4e/1jSauWLmR7xvac7bkSagBQofz+vLCwUHU5QO2VcWOF90XEvO3flfSQ7WfyEyMibL/p1jERMStpVpJ6TW+7iPj1re2GeQbqLL8/T09PT9T+fPH5ty17mUcP3DyGStAkhXu+ETGfPR+V9ICkDZKO2F4tSdnz0aLraZtuoA77DABoj0Lha/tU26d1X0v6kKR9kh6UtDWbbaukbxZZTxt17yM77DMAoD2KDjuvkvRA1jtbKelrEfGvtp+QdL/tayW9IOmagutpHXq+ADC5CoVvRDwv6Q96tL8k6fIin912HPMFgMnFFa4qQs8XACYX4VsRjvkCwOQifCtCzxcAJhfhWxF6vgAwuQjfitDzBYDJRfhWhJ4vAEwuwrci9HwBYHIRvhWh5wsAk6uMGytgBPR8F3ERETQdN0n4DX/5KsV1/1x1GY1AzxeVo3cPtIe/fFXVJTQC4YtaIICB9iCAByN8URsEMNAeBPDJEb6oFQIYaA8CuD/CF7VDAAPtQQD3RviilghgoD0I4DcjfFFbBDDQHgTwG40cvrbfbntP7vGK7c/YvtX2fK79yjILxmQhgIH2IIB/Y+TwjYhnI2IqIqYkXSTpNUkPZJPv6E6LiF1lFIrJRQAD7UEALypr2PlySQci4oWSPg94AwIYaA8CuLzw3SxpR+79jbb32t5m+4xeC9iesT1ne66kGtByBHB95ffnhYWFqstBA0x6ABcOX9tvkfQRSf+YNd0p6XxJU5IOS7q913IRMRsR0xExXbQGTA4CuJ7y+3On06m6HDTEJAdwGT3fKyQ9GRFHJCkijkTE8Yg4IekuSRtKWAfwawQw0B6TGsBlhO8W5Yacba/OTbta0r4S1gG8AQEMtMckBnCh8LV9qqQPSvp6rvlztr9ve6+kyyT9eZF1AP0QwEB7TFoAF7qfb0T8XNJbl7R9olBFwDJwP2CgPSbpfsCFwhcoiuAE2mNSgrMMXF4SAIDECF8AABIjfAEASIzwBQAgMcIXAIDECF8AABIjfAEASIzwBQAgMcIXAIDECF8AABIjfAEASIzwBQAgMcIXAIDECF8AABIjfAEASGyo+/na3ibpKklHI+JdWduZku6TtE7SQUnXRMRPvXiD1v8p6UpJr0n6VEQ8WX7paUXEspfhXrVAPT3+3suWvcyG//j2GCrBpBq253uPpI1L2m6S9HBErJf0cPZekq6QtD57zEi6s3iZAAC0x1DhGxGPSHp5SfMmSduz19slfTTX/pVY9Jik022vLqNYAADaoMgx31URcTh7/WNJq7LXayS9mJvvUNb2BrZnbM/ZnitQA4AayO/PCwsLVZcD1F4pJ1zF4gHRZR0UjYjZiJiOiOkyagBQnfz+3Ol0qi4HqL0i4XukO5ycPR/N2uclnZOb7+ysDQAAqFj4Pihpa/Z6q6Rv5to/6UXvkfSz3PA0AAATb9ivGu2QdKmks2wfknSLpL+VdL/tayW9IOmabPZdWvya0X4tftXo0yXXDABAow0VvhGxpc+ky3vMG5JuKFIUAABtxhWuAABIjPAF
ACAxwhcAgMQIXwAAEhvqhCtwkwSgTbhJAqpGzxcAgMQIXwAAEiN8AQBIjPAFACAxwhcAgMQIXwAAEiN8AQBIjPAFACAxwhcAgMQIXwAAEiN8AQBIbGD42t5m+6jtfbm2v7P9jO29th+wfXrWvs72L2zvyR5fGmfxAAA00TA933skbVzS9pCkd0XE70v6oaTP5qYdiIip7HF9OWUCANAeA8M3Ih6R9PKStm9FxLHs7WOSzh5DbQAAtFIZx3z/VNK/5N6fa/t7tr9j+5J+C9mesT1ne66EGgBUKL8/LywsVF0OUHuFwtf2zZKOSfpq1nRY0tqIeLekv5D0Ndv/pdeyETEbEdMRMV2kBgDVy+/PnU6n6nKA2hs5fG1/StJVkv4kIkKSIuL1iHgpe71b0gFJbyuhTgAAWmOk8LW9UdJfSfpIRLyWa+/YXpG9Pk/SeknPl1EoAABtsXLQDLZ3SLpU0lm2D0m6RYtnN58i6SHbkvRYdmbz+yX9je1fSToh6fqIeLnnBwMAMKEGhm9EbOnRfHefeXdK2lm0KAAA2owrXAEAkBjhCwBAYoQvAACJEb4AACRG+AIAkBjhCwBAYoQvAACJEb4AACRG+AIAkBjhCwBAYoQvAACJEb4AACRG+AIAkBjhCwBAYoQvAACJEb4AACQ2MHxtb7N91Pa+XNuttudt78keV+amfdb2ftvP2v7wuAoHAKCphun53iNpY4/2OyJiKnvskiTb75S0WdIF2TJftL2irGIBAGiDlYNmiIhHbK8b8vM2Sbo3Il6X9CPb+yVtkPToyBWOQUT0nWY7YSUAinp1T/8+xGlTJxJWAgyvyDHfG23vzYalz8ja1kh6MTfPoaztTWzP2J6zPVeghmWJiJMG77DzAHij/P68sLCQZJ2v7vmtkwbvsPMAVRj1r/JOSedLmpJ0WNLty/2AiJiNiOmImB6xhuWub6zzA5Msvz93Op2xr2+5gUoAo25G+ouMiCMRcTwiTki6S4tDy5I0L+mc3KxnZ20AACAzUvjaXp17e7Wk7pnQD0rabPsU2+dKWi/p8WIlFjdqL5beL1A/o/Zi6f2iTgaecGV7h6RLJZ1l+5CkWyRdantKUkg6KOk6SYqIp2zfL+lpScck3RARx8dTOgAAzTTM2c5bejTffZL5b5N0W5GiAABoM8ZhAABIjPAFACAxwhcAgMQIXwAAEpuI8B31kpFcanJ0XCkM4zLqJSO51OTotn/3Ym3/7sVVl9EqExG+AADUycSE73J7sfR6gfpabi+WXi/qZuD3fNukG6jc1Qhovm6gclcjNNFEhW8XAQu0BwGLJpqYYWcAAOqC8AUAIDHCFwCAxCbymC+KG/Y7vIPm4/g7UL1hv8M7aL6tlzxaRjkTgZ4vAACJ0fPFSAb1WLs9Xnq2QP0N6rF2e7z0bMtDzxcAgMQGhq/tbbaP2t6Xa7vP9p7scdD2nqx9ne1f5KZ9aZzFAwDQRMMMO98j6R8kfaXbEBF/3H1t+3ZJP8vNfyAipsoqEACAthkYvhHxiO11vaZ58YDeNZI+UG5ZAAC0V9FjvpdIOhIRz+XazrX9PdvfsX1JvwVtz9iesz1XsAYAFcvvzwsLC1WXA9Re0fDdImlH7v1hSWsj4t2S/kLS12z/l14LRsRsRExHxHTBGgBULL8/dzqdqssBam/k8LW9UtLHJN3XbYuI1yPipez1bkkHJL2taJEAALRJke/5/pGkZyLiULfBdkfSyxFx3PZ5ktZLer5gjWggvt8LtAff7y3fMF812iHpUUlvt33I9rXZpM1645CzJL1f0t7sq0f/JOn6iHi5zIIBAGi6Yc523tKn/VM92nZK2lm8LAAA2osrXAEAkBjhCwBAYoQvAACJEb4AACRG+AIAkBjhCwBAYoQvAACJEb4AACRG+AIAkBjhCwBAYoQvAACJOSKqrkG2FyT9XNJPqq6lBGeJ7aiTJmzH70VEa26Ca/tVSc9WXUcJmvC3Mwy2I62h9udahK8k2Z6L
iOmq6yiK7aiXtmxHk7TlZ8521EtbtqOLYWcAABIjfAEASKxO4TtbdQElYTvqpS3b0SRt+ZmzHfXSlu2QVKNjvgAATIo69XwBAJgIlYev7Y22n7W93/ZNVdezHLYP2v6+7T2257K2M20/ZPu57PmMqutcyvY220dt78u19azbi/4++/3stX1hdZW/UZ/tuNX2fPY72WP7yty0z2bb8aztD1dTdbuxP6fH/tzM/bnS8LW9QtL/knSFpHdK2mL7nVXWNILLImIqdwr8TZIejoj1kh7O3tfNPZI2LmnrV/cVktZnjxlJdyaqcRj36M3bIUl3ZL+TqYjYJUnZ39VmSRdky3wx+/tDSdifK3OP2J8btz9X3fPdIGl/RDwfEb+UdK+kTRXXVNQmSduz19slfbTCWnqKiEckvbykuV/dmyR9JRY9Jul026vTVHpyfbajn02S7o2I1yPiR5L2a/HvD+Vhf64A+3Mz9+eqw3eNpBdz7w9lbU0Rkr5le7ftmaxtVUQczl7/WNKqakpbtn51N/F3dGM2pLYtN0zYxO1omqb/jNmf66mV+3PV4dt074uIC7U4lHOD7ffnJ8biqeSNO528qXVn7pR0vqQpSYcl3V5tOWgQ9uf6ae3+XHX4zks6J/f+7KytESJiPns+KukBLQ57HOkO42TPR6urcFn61d2o31FEHImI4xFxQtJd+s1QVKO2o6Ea/TNmf66fNu/PVYfvE5LW2z7X9lu0eAD9wYprGortU22f1n0t6UOS9mmx/q3ZbFslfbOaCpetX90PSvpkdpbkeyT9LDecVTtLjl9drcXfibS4HZttn2L7XC2ecPJ46vpajv25Ptif6y4iKn1IulLSDyUdkHRz1fUso+7zJP2/7PFUt3ZJb9Xi2YXPSfo/ks6sutYete/Q4hDOr7R4rOTafnVLshbPYD0g6fuSpquuf8B2/O+szr1a3EFX5+a/OduOZyVdUXX9bXywP1dSO/tzA/dnrnAFAEBiVQ87AwAwcQhfAAASI3wBAEiM8AUAIDHCFwCAxAhfAAASI3wBAEiM8AUAILH/D3LKQakHH5b0AAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 576x864 with 6 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "%matplotlib inline\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import os,sys\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import helper\n",
    "import simulation\n",
    "\n",
    "# Generate a few random images together with their target masks\n",
    "input_images, target_masks = simulation.generate_random_data(192, 192, count=3)\n",
    "\n",
    "# Show the shape and value range of both arrays\n",
    "for x in (input_images, target_masks):\n",
    "    print(x.shape)\n",
    "    print(x.min(), x.max())\n",
    "\n",
    "# Cast each input image to uint8 for plotting with matplotlib\n",
    "input_images_rgb = [image.astype(np.uint8) for image in input_images]\n",
    "\n",
    "# Map each mask channel (i.e. class) to a color\n",
    "target_masks_rgb = list(map(helper.masks_to_colorimg, target_masks))\n",
    "\n",
    "# Left: input image, right: target mask (ground truth)\n",
    "helper.plot_side_by_side([input_images_rgb, target_masks_rgb])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'train': 2000, 'val': 200}"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from torch.utils.data import Dataset, DataLoader\n",
    "from torchvision import transforms, datasets, models\n",
    "\n",
    "class SimDataset(Dataset):\n",
    "    \"\"\"In-memory dataset of simulated images and their target masks.\n",
    "\n",
    "    All samples are produced once, at construction time, by\n",
    "    ``simulation.generate_random_data(192, 192, count=count)``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, count, transform=None):\n",
    "        \"\"\"Generate `count` samples; `transform`, if given, is applied to each image on access.\"\"\"\n",
    "        images, masks = simulation.generate_random_data(192, 192, count=count)\n",
    "        self.input_images = images\n",
    "        self.target_masks = masks\n",
    "        self.transform = transform\n",
    "\n",
    "    def __len__(self):\n",
    "        \"\"\"Return the number of generated samples.\"\"\"\n",
    "        return len(self.input_images)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        \"\"\"Return `[image, mask]` for sample `idx`; only the image goes through the transform.\"\"\"\n",
    "        image, mask = self.input_images[idx], self.target_masks[idx]\n",
    "        if not self.transform:\n",
    "            return [image, mask]\n",
    "        return [self.transform(image), mask]\n",
    "\n",
    "# Train and validation share the same transform in this example\n",
    "trans = transforms.Compose([transforms.ToTensor()])\n",
    "\n",
    "train_set = SimDataset(2000, transform=trans)\n",
    "val_set = SimDataset(200, transform=trans)\n",
    "\n",
    "image_datasets = {'train': train_set, 'val': val_set}\n",
    "\n",
    "batch_size = 25\n",
    "\n",
    "# One shuffled loader per phase; num_workers=0 keeps loading in the main process\n",
    "dataloaders = {\n",
    "    phase: DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)\n",
    "    for phase, dataset in image_datasets.items()\n",
    "}\n",
    "\n",
    "# Number of samples available in each phase\n",
    "dataset_sizes = {phase: len(dataset) for phase, dataset in image_datasets.items()}\n",
    "\n",
    "dataset_sizes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([25, 3, 192, 192]) torch.Size([25, 6, 192, 192])\n",
      "0.0 1.0 0.02312283 0.1502936\n",
      "0.0 1.0 0.004655129 0.06806962\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<matplotlib.image.AxesImage at 0x7fe3ba443da0>"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAQUAAAD8CAYAAAB+fLH0AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAADn9JREFUeJzt3X+sZGV9x/H3p1j9w5oA1W4IYEGzmohptkrQpGqwrYqkcaV/0CVNpWq6mEDSP5o0aJNq2jRpWqmJqT+ypgRMFKStCDFUpaTRf0plqQQFRRaEsJt1qdioraa68O0fc67Oc7337tyZM7/fr2QyZ545M+c5e3c+93nOmXu+qSokacMvzLsDkhaLoSCpYShIahgKkhqGgqSGoSCpMbVQSHJJkoeSHEly7bS2I6lfmcb3FJKcBnwTeANwFLgHuKKqHux9Y5J6Na2RwkXAkap6tKp+DNwM7J/StiT16FlTet+zgSeGHh8FXrXdykn8WqU0fd+pqhecaqVphcIpJTkIHJzX9qU19PgoK00rFI4B5w49Pqdr+6mqOgQcAkcK0iKZ1jGFe4C9Sc5P8mzgAHD7lLYlqUdTGSlU1ckk1wCfB04Drq+qB6axLUn9msopyV13wumDNAv3VtWFp1rJbzRKahgKkhqGgqSGoSCpYShIahgKkhqGgqSGoSCpYShIahgKkhqGgqSGoSCpYShIahgKkhqGgqSGoSCpYShIahgKkhqGgqSGoSCpMXYoJDk3yb8leTDJA0n+uGt/X5JjSe7rbpf2111J0zbJJd5PAn9SVf+Z5HnAvUnu7J77QFW9f/LuSZq1sUOhqo4Dx7vlHyT5OoMakpKWWC/HFJKcB/w68B9d0zVJ7k9yfZIztnnNwSSHkxzuow+S+jFxMZgkvwR8Efirqvp0kj3Ad4AC/hI4q6recYr3sBiMNH3TLwaT5BeBfwY+UVWfBqiqE1X1dFU9A3wMuGiSbUiarUnOPgT4B+DrVfV3Q+1nDa12GfC18bsnadYmOfvwG8AfAF9Ncl/X9h7giiT7GEwfHgOumqiHkmbKArPS+rDArKTdMxQkNQwFSQ1DQVLDUBiyCAddpXkzFDYxGLTuDIUtGAxaZ4bCNqrKcNBaMhROwWDQujEURmAwaJ0YCiMyGLQuDAVJDUNhFxwtaB0YCrtkMGjVGQpjMBi0ygyFMfk9Bq0qQ2FCBoNWjaHQA4NBq8RQ6InBoFUxyYVbAUjyGPAD4GngZFVdmORM4FPAeQwu3np5Vf33pNuatsEFqqX11tdI4fVVtW/oopDXAndV1V7gru6xpCUwrenDfuDGbvlG4K1T2o6knvURCgV8Icm9SQ52bXu6ArQA3wb29LAdSTMw8TEF4DVVdSzJrwB3JvnG8JNVVVvVdegC5ODmdknzNfFIoaqOdfdPArcyqB15YqN8XHf/5BavO1RVF45SnELS7ExaYPa5SZ63sQy8kUHtyNuBK7vVrgRum2Q7kmZn0unDHuDW7lTes4BPVtXnktwD3JLkncDjwOUTbkfSjFhLUlof1pKUtHuGgqSGoSCpYShIahgKkhqGgqSGoSCp0cffPmgJjPp9FK8pIUcKkhqGgqSGoSCpYShIahgKkhqGgqSGoSCpYShIahgKkhqGgqSGX3NeE359WaMaOxSSvJRBvcgNLwL+HDgd+CPgv7r291TVHWP3UNJM9XLh1iSnAceAVwFvB/6nqt6/i9d74VZp+mZ64dbfAh6pqsd7ej9Jc9JXKBwAbhp6fE2S+5Ncn+SMnrahORgeSW4sb3W/CKUC1I+JQyHJs4G3AP/YNX0EeDGwDzgOXLfN6w4mOZzk8KR90PQMH6DcWN7q3gOZq2PiYwpJ9gNXV9Ubt3juPOCzVfXyU7yHv2YWVFX99AO/sbzVPXiGYwnM7JjCFQxNHTYKy3YuY1BbUtKSmOh7Cl1R2TcAVw01/02SfUABj216Tktm1OmD
Voe1JLUjpw8rZaTpg99o1I4cKawf//ZBO/KU5PpxpKAdOVJYP44UtCNHCuvHkYJ25Ehh/ThSkNQwFLQjpw/rx+mDduT0Yf04UpDUMBQkNQwFSQ1DQVLDUJDUMBQkNQwFSQ1DQVLDUJDUMBQkNQwFSQ1DQVJjpFDoKj09meRrQ21nJrkzycPd/Rlde5J8MMmRrkrUK6bVeUn9G3WkcANwyaa2a4G7qmovcFf3GODNwN7udpBBxShJS2KkUKiqLwHf3dS8H7ixW74ReOtQ+8dr4G7g9E0FYiQtsEmOKeypquPd8reBPd3y2cATQ+sd7doa1pKUFlMvF1mpqtptQZeqOgQcAovBSItkkpHCiY1pQXf/ZNd+DDh3aL1zujZJS2CSULgduLJbvhK4baj9bd1ZiFcD3xuaZkhadBsX3dzpxqCq9HHgJwyOEbwT+GUGZx0eBv4VOLNbN8CHgEeArwIXjvD+5c2bt6nfDo/yebfArLQ+Riow6zcaJTUMBUkNQ0FSw1CQ1DAUJDUMBUkNQ0FSw1CQ1DAUJDUMBUkNQ0FSw1CQ1DAUJDUMBUkNQ0FSw1CQ1DAUJDUMBUkNQ0FS45ShsE0dyb9N8o2uVuStSU7v2s9L8qMk93W3j06z85L6N8pI4QZ+vo7kncDLq+rXgG8C7x567pGq2tfd3tVPNyXNyilDYas6klX1hao62T28m0HBF0kroI9jCu8A/mXo8flJvpLki0leu92LrCUpLaaJakkm+TPgJPCJruk48MKqeirJK4HPJLmgqr6/+bXWkpQW09gjhSR/CPwO8Pu1Ueap6v+q6qlu+V4GVaJe0kM/Jc3IWKGQ5BLgT4G3VNUPh9pfkOS0bvlFwF7g0T46Kmk2Tjl9SHITcDHw/CRHgfcyONvwHODOJAB3d2caXgf8RZKfAM8A76qq7275xpIWkrUkpfVhLUlJu2coSGpMdEpyUY0zJeqOjUhrb6VCYZLjIxuvXfZwWJX90PyszPShrwOmi3DgVZqnpR8pTONDPPye/sbVulnqUBglEE71oT7Ve1SVwaC1sjLTB0n9WOqRwnZ285t9Y12PJUgDSzlSqKptP8TjDvV3ep2BoXWydKGw0wd00rm/wSAtYShsp6+DgR5U1LpbqlDoe8qwne3ez9GC1sFKHmhcRbsNpFHXd2SkzZZqpLCVaf2nTuIHRmvJkcIuzes38Kjv598+aFJLP1KQ1C9DYUo8KKllZShM0U5fspIW1bi1JN+X5NhQzchLh557d5IjSR5K8qZpdXyZGAxaJuPWkgT4wFDNyDsAkrwMOABc0L3mwxuXfF93BoOWxVi1JHewH7i5KwrzLeAIcNEE/ZM0Y5McU7imK0V/fZIzurazgSeG1jnatU3NtH4DezxA62rcUPgI8GJgH4P6kdft9g0sMCstprFCoapOVNXTVfUM8DF+NkU4Bpw7tOo5XdtW73Goqi4cpTiFpNkZt5bkWUMPLwM2zkzcDhxI8pwk5zOoJfnlybrYbHfL9r6H+U4btM7GrSV5cZJ9QAGPAVcBVNUDSW4BHmRQov7qqnp6Ol1v9XUtxWUPBL/erEktZS3JaV1oZdr/Fn5gNWerW0typ79gHPeDvQjhKC2Clfwryd3UbTAMpNZSh0KSkeo2THsbo76PtAyWOhRgOpdo3/wBnjQYDAQtk6U8piBpelYmFKZ9Nedx3t9LumkZrUwowGQfwlFeO07lKWnZLP0xha1M8wPph12rbqVGCpImZyhIahgKkhqGgqSGoSCpYShIahgKkhqGgqSGoSCpYShIahgKkhrj1pL81FAdyceS3Ne1n5fkR0PPfXSanZfUv1H+IOoG4O+Bj280VNXvbSwnuQ743tD6j1TVvr46KGm2ThkKVfWlJOdt9VwGfzJ4OfCb/XZL0rxMekzhtcCJqnp4qO38JF9J8sUkr53w/SXN2KTXU7gCuGno8XHghVX1VJJXAp9JckFVfX/zC5McBA5OuH1JPRt7pJDkWcDvAp/aaOtK
0D/VLd8LPAK8ZKvXW0tSWkyTTB9+G/hGVR3daEjygiSndcsvYlBL8tHJuihplkY5JXkT8O/AS5McTfLO7qkDtFMHgNcB93enKP8JeFdVfbfPDkuarqWsJSlpLKtbS1LS9BgKkhqGgqSGoSCpYShIahgKkhqGgqSGoSCpYShIahgKkhqGgqSGoSCpYShIahgKkhqTXo6tL98B/re7X2XPZ7X3cdX3D5Z7H391lJUW4noKAEkOr/ql2VZ9H1d9/2A99tHpg6SGoSCpsUihcGjeHZiBVd/HVd8/WIN9XJhjCpIWwyKNFCQtgLmHQpJLkjyU5EiSa+fdn7501bi/2lXfPty1nZnkziQPd/dnzLufu7FNBfIt9ykDH+x+rvcnecX8ej6abfbvfUmODVVSv3TouXd3+/dQkjfNp9f9m2sodIVjPgS8GXgZcEWSl82zTz17fVXtGzqFdS1wV1XtBe7qHi+TG4BLNrVtt09vZlAMaC+D8oAfmVEfJ3EDP79/AB/ofo77quoOgO7/6QHggu41H94ohLTs5j1SuAg4UlWPVtWPgZuB/XPu0zTtB27slm8E3jrHvuxaVX0J2FzcZ7t92g98vAbuBk5PctZsejqebfZvO/uBm7tSid8CjjD4/7z05h0KZwNPDD0+2rWtggK+kOTerpguwJ6qOt4tfxvYM5+u9Wq7fVqln+013RTo+qEp3yrtX2PeobDKXlNVr2AwjL46yeuGn6zBaZ+VOvWzivvEYNrzYmAfg6rq1823O9M371A4Bpw79Picrm3pVdWx7v5J4FYGQ8sTG0Po7v7J+fWwN9vt00r8bKvqRFU9XVXPAB/jZ1OEldi/rcw7FO4B9iY5P8mzGRy4uX3OfZpYkucmed7GMvBG4GsM9u3KbrUrgdvm08NebbdPtwNv685CvBr43tA0Y2lsOg5yGYOfIwz270CS5yQ5n8EB1S/Pun/TMNe/kqyqk0muAT4PnAZcX1UPzLNPPdkD3JoEBv/Gn6yqzyW5B7ilq9z9OHD5HPu4a10F8ouB5yc5CrwX+Gu23qc7gEsZHID7IfD2mXd4l7bZv4uT7GMwLXoMuAqgqh5IcgvwIHASuLqqnp5Hv/vmNxolNeY9fZC0YAwFSQ1DQVLDUJDUMBQkNQwFSQ1DQVLDUJDU+H/1Y/rQJgrTIwAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import torchvision.utils\n",
    "\n",
    "def reverse_transform(inp):\n",
    "    \"\"\"Convert a (channels, height, width) image tensor into a uint8 array for plotting.\n",
    "\n",
    "    Args:\n",
    "        inp: 3-D image tensor laid out as (channels, height, width). Values are\n",
    "            expected in [0, 1]; anything outside that range is clipped. The\n",
    "            tensor may live on any device and may be tracking gradients.\n",
    "\n",
    "    Returns:\n",
    "        numpy.ndarray of dtype uint8 laid out as (height, width, channels),\n",
    "        with values in [0, 255].\n",
    "    \"\"\"\n",
    "    # detach() + cpu() keep this usable for CUDA tensors and tensors that\n",
    "    # require grad; a bare .numpy() raises for both.\n",
    "    inp = inp.detach().cpu().numpy().transpose((1, 2, 0))\n",
    "    inp = np.clip(inp, 0, 1)\n",
    "    # astype truncates toward zero, so e.g. 0.5 -> 127\n",
    "    inp = (inp * 255).astype(np.uint8)\n",
    "\n",
    "    return inp\n",
    "\n",
    "# Fetch a single batch from the training loader\n",
    "inputs, masks = next(iter(dataloaders['train']))\n",
    "\n",
    "print(inputs.shape, masks.shape)\n",
    "# Value range, mean and standard deviation of the images and of the masks\n",
    "for x in (inputs.numpy(), masks.numpy()):\n",
    "    print(x.min(), x.max(), x.mean(), x.std())\n",
    "\n",
    "# Show the fourth image of the batch\n",
    "plt.imshow(reverse_transform(inputs[3]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "----------------------------------------------------------------\n",
      "        Layer (type)               Output Shape         Param #\n",
      "================================================================\n",
      "            Conv2d-1         [-1, 64, 224, 224]           1,792\n",
      "              ReLU-2         [-1, 64, 224, 224]               0\n",
      "            Conv2d-3         [-1, 64, 224, 224]          36,928\n",
      "              ReLU-4         [-1, 64, 224, 224]               0\n",
      "         MaxPool2d-5         [-1, 64, 112, 112]               0\n",
      "            Conv2d-6        [-1, 128, 112, 112]          73,856\n",
      "              ReLU-7        [-1, 128, 112, 112]               0\n",
      "            Conv2d-8        [-1, 128, 112, 112]         147,584\n",
      "              ReLU-9        [-1, 128, 112, 112]               0\n",
      "        MaxPool2d-10          [-1, 128, 56, 56]               0\n",
      "           Conv2d-11          [-1, 256, 56, 56]         295,168\n",
      "             ReLU-12          [-1, 256, 56, 56]               0\n",
      "           Conv2d-13          [-1, 256, 56, 56]         590,080\n",
      "             ReLU-14          [-1, 256, 56, 56]               0\n",
      "        MaxPool2d-15          [-1, 256, 28, 28]               0\n",
      "           Conv2d-16          [-1, 512, 28, 28]       1,180,160\n",
      "             ReLU-17          [-1, 512, 28, 28]               0\n",
      "           Conv2d-18          [-1, 512, 28, 28]       2,359,808\n",
      "             ReLU-19          [-1, 512, 28, 28]               0\n",
      "         Upsample-20          [-1, 512, 56, 56]               0\n",
      "           Conv2d-21          [-1, 256, 56, 56]       1,769,728\n",
      "             ReLU-22          [-1, 256, 56, 56]               0\n",
      "           Conv2d-23          [-1, 256, 56, 56]         590,080\n",
      "             ReLU-24          [-1, 256, 56, 56]               0\n",
      "         Upsample-25        [-1, 256, 112, 112]               0\n",
      "           Conv2d-26        [-1, 128, 112, 112]         442,496\n",
      "             ReLU-27        [-1, 128, 112, 112]               0\n",
      "           Conv2d-28        [-1, 128, 112, 112]         147,584\n",
      "             ReLU-29        [-1, 128, 112, 112]               0\n",
      "         Upsample-30        [-1, 128, 224, 224]               0\n",
      "           Conv2d-31         [-1, 64, 224, 224]         110,656\n",
      "             ReLU-32         [-1, 64, 224, 224]               0\n",
      "           Conv2d-33         [-1, 64, 224, 224]          36,928\n",
      "             ReLU-34         [-1, 64, 224, 224]               0\n",
      "           Conv2d-35          [-1, 6, 224, 224]             390\n",
      "================================================================\n",
      "Total params: 7,783,238\n",
      "Trainable params: 7,783,238\n",
      "Non-trainable params: 0\n",
      "----------------------------------------------------------------\n"
     ]
    }
   ],
   "source": [
    "from torchsummary import summary\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import pytorch_unet\n",
    "\n",
    "# Use the GPU when one is available, otherwise fall back to the CPU\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "# Build the U-Net (6 matches the number of mask channels) and move it to the device\n",
    "model = pytorch_unet.UNet(6).to(device)\n",
    "\n",
    "# Layer summary for a 3x224x224 input (note: the generated data above is 192x192)\n",
    "summary(model, input_size=(3, 224, 224))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import defaultdict\n",
    "import torch.nn.functional as F\n",
    "from loss import dice_loss\n",
    "\n",
    "def calc_loss(pred, target, metrics, bce_weight=0.5):\n",
    "    \"\"\"Compute the weighted BCE + Dice loss and accumulate running metrics.\n",
    "\n",
    "    Args:\n",
    "        pred: raw model outputs (logits); sigmoid is applied here, not by the caller.\n",
    "        target: ground-truth masks; its first dimension is taken as the batch size.\n",
    "        metrics: mutable mapping updated in place. 'bce', 'dice' and 'loss' are\n",
    "            incremented with `+=`, so pass a defaultdict(float) or pre-fill the keys.\n",
    "        bce_weight: weight of the BCE term; the Dice term gets `1 - bce_weight`.\n",
    "\n",
    "    Returns:\n",
    "        The combined loss tensor, suitable for calling `.backward()` on.\n",
    "    \"\"\"\n",
    "    bce = F.binary_cross_entropy_with_logits(pred, target)\n",
    "\n",
    "    # torch.sigmoid replaces the deprecated F.sigmoid\n",
    "    pred = torch.sigmoid(pred)\n",
    "    dice = dice_loss(pred, target)\n",
    "\n",
    "    loss = bce * bce_weight + dice * (1 - bce_weight)\n",
    "\n",
    "    # Weight each batch value by the batch size so the caller can divide by the\n",
    "    # total sample count; .item() yields plain Python floats.\n",
    "    batch_size = target.size(0)\n",
    "    metrics['bce'] += bce.item() * batch_size\n",
    "    metrics['dice'] += dice.item() * batch_size\n",
    "    metrics['loss'] += loss.item() * batch_size\n",
    "\n",
    "    return loss\n",
    "\n",
    "def print_metrics(metrics, epoch_samples, phase):\n",
    "    \"\"\"Print every accumulated metric divided by `epoch_samples`, prefixed by `phase`.\"\"\"\n",
    "    # '4f' is a minimum width of 4 with the default six decimals (e.g. 0.500000)\n",
    "    averaged = \", \".join(\n",
    "        \"{}: {:4f}\".format(name, total / epoch_samples)\n",
    "        for name, total in metrics.items()\n",
    "    )\n",
    "    print(\"{}: {}\".format(phase, averaged))\n",
    "\n",
    "def train_model(model, optimizer, scheduler, num_epochs=25):\n",
    "    best_model_wts = copy.deepcopy(model.state_dict())\n",
    "    best_loss = 1e10\n",
    "\n",
    "    for epoch in range(num_epochs):\n",
    "        print('Epoch {}/{}'.format(epoch, num_epochs - 1))\n",
    "        print('-' * 10)\n",
    "        \n",
    "        since = time.time()\n",
    "\n",
    "        # Each epoch has a training and validation phase\n",
    "        for phase in ['train', 'val']:\n",
    "            if phase == 'train':\n",
    "                scheduler.step()\n",
    "                for param_group in optimizer.param_groups:\n",
    "                    print(\"LR\", param_group['lr'])\n",
    "                    \n",
    "                model.train()  # Set model to training mode\n",
    "            else:\n",
    "                model.eval()   # Set model to evaluate mode\n",
    "\n",
    "            metrics = defaultdict(float)\n",
    "            epoch_samples = 0\n",
    "            \n",
    "            for inputs, labels in dataloaders[phase]:\n",
    "                inputs = inputs.to(device)\n",
    "                labels = labels.to(device)             \n",
    "\n",
    "                # zero the parameter gradients\n",
    "                optimizer.zero_grad()\n",
    "\n",
    "                # forward\n",
    "                # track history if only in train\n",
    "                with torch.set_grad_enabled(phase == 'train'):\n",
    "                    outputs = model(inputs)\n",
    "                    loss = calc_loss(outputs, labels, metrics)\n",
    "\n",
    "                    # backward + optimize only if in training phase\n",
    "                    if phase == 'train':\n",
    "                        loss.backward()\n",
    "                        optimizer.step()\n",
    "\n",
    "                # statistics\n",
    "                epoch_samples += inputs.size(0)\n",
    "\n",
    "            print_metrics(metrics, epoch_samples, phase)\n",
    "            epoch_loss = metrics['loss'] / epoch_samples\n",
    "\n",
    "            # deep copy the model\n",
    "            if phase == 'val' and epoch_loss < best_loss:\n",
    "                print(\"saving best model\")\n",
    "                best_loss = epoch_loss\n",
    "                best_model_wts = copy.deepcopy(model.state_dict())\n",
    "\n",
    "        time_elapsed = time.time() - since\n",
    "        print('{:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))\n",
    "    print('Best val loss: {:4f}'.format(best_loss))\n",
    "\n",
    "    # load best model weights\n",
    "    model.load_state_dict(best_model_wts)\n",
    "    return model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cuda:0\n",
      "Epoch 0/39\n",
      "----------\n",
      "LR 0.0001\n",
      "train: bce: 0.210124, dice: 0.994346, loss: 0.602235\n",
      "val: bce: 0.030143, dice: 0.986439, loss: 0.508291\n",
      "saving best model\n",
      "0m 43s\n",
      "Epoch 1/39\n",
      "----------\n",
      "LR 0.0001\n",
      "train: bce: 0.022030, dice: 0.806168, loss: 0.414099\n",
      "val: bce: 0.023499, dice: 0.671528, loss: 0.347514\n",
      "saving best model\n",
      "0m 43s\n",
      "Epoch 2/39\n",
      "----------\n",
      "LR 0.0001\n",
      "train: bce: 0.023134, dice: 0.522101, loss: 0.272618\n",
      "val: bce: 0.017994, dice: 0.439513, loss: 0.228753\n",
      "saving best model\n",
      "0m 43s\n",
      "Epoch 3/39\n",
      "----------\n",
      "LR 0.0001\n",
      "train: bce: 0.015791, dice: 0.392756, loss: 0.204273\n",
      "val: bce: 0.015154, dice: 0.353304, loss: 0.184229\n",
      "saving best model\n",
      "0m 43s\n",
      "Epoch 4/39\n",
      "----------\n",
      "LR 0.0001\n",
      "train: bce: 0.012854, dice: 0.299000, loss: 0.155927\n",
      "val: bce: 0.011838, dice: 0.235490, loss: 0.123664\n",
      "saving best model\n",
      "0m 43s\n",
      "Epoch 5/39\n",
      "----------\n",
      "LR 0.0001\n",
      "train: bce: 0.010764, dice: 0.217516, loss: 0.114140\n",
      "val: bce: 0.010928, dice: 0.202027, loss: 0.106478\n",
      "saving best model\n",
      "0m 43s\n",
      "Epoch 6/39\n",
      "----------\n",
      "LR 0.0001\n",
      "train: bce: 0.010902, dice: 0.222725, loss: 0.116813\n",
      "val: bce: 0.010661, dice: 0.192998, loss: 0.101830\n",
      "saving best model\n",
      "0m 43s\n",
      "Epoch 7/39\n",
      "----------\n",
      "LR 0.0001\n",
      "train: bce: 0.009604, dice: 0.184641, loss: 0.097122\n",
      "val: bce: 0.010067, dice: 0.181135, loss: 0.095601\n",
      "saving best model\n",
      "0m 43s\n",
      "Epoch 8/39\n",
      "----------\n",
      "LR 0.0001\n",
      "train: bce: 0.009128, dice: 0.176201, loss: 0.092664\n",
      "val: bce: 0.008653, dice: 0.176254, loss: 0.092453\n",
      "saving best model\n",
      "0m 43s\n",
      "Epoch 9/39\n",
      "----------\n",
      "LR 0.0001\n",
      "train: bce: 0.008457, dice: 0.170643, loss: 0.089550\n",
      "val: bce: 0.008299, dice: 0.171656, loss: 0.089977\n",
      "saving best model\n",
      "0m 43s\n",
      "Epoch 10/39\n",
      "----------\n",
      "LR 0.0001\n",
      "train: bce: 0.007046, dice: 0.151076, loss: 0.079061\n",
      "val: bce: 0.005749, dice: 0.138535, loss: 0.072142\n",
      "saving best model\n",
      "0m 43s\n",
      "Epoch 11/39\n",
      "----------\n",
      "LR 0.0001\n",
      "train: bce: 0.004789, dice: 0.094846, loss: 0.049817\n",
      "val: bce: 0.004794, dice: 0.082758, loss: 0.043776\n",
      "saving best model\n",
      "0m 43s\n",
      "Epoch 12/39\n",
      "----------\n",
      "LR 0.0001\n",
      "train: bce: 0.003822, dice: 0.066693, loss: 0.035258\n",
      "val: bce: 0.004868, dice: 0.075574, loss: 0.040221\n",
      "saving best model\n",
      "0m 43s\n",
      "Epoch 13/39\n",
      "----------\n",
      "LR 0.0001\n",
      "train: bce: 0.003647, dice: 0.065981, loss: 0.034814\n",
      "val: bce: 0.005102, dice: 0.078447, loss: 0.041774\n",
      "0m 43s\n",
      "Epoch 14/39\n",
      "----------\n",
      "LR 0.0001\n",
      "train: bce: 0.003680, dice: 0.068849, loss: 0.036265\n",
      "val: bce: 0.004177, dice: 0.066650, loss: 0.035413\n",
      "saving best model\n",
      "0m 43s\n",
      "Epoch 15/39\n",
      "----------\n",
      "LR 0.0001\n",
      "train: bce: 0.003029, dice: 0.053153, loss: 0.028091\n",
      "val: bce: 0.003654, dice: 0.061158, loss: 0.032406\n",
      "saving best model\n",
      "0m 43s\n",
      "Epoch 16/39\n",
      "----------\n",
      "LR 0.0001\n",
      "train: bce: 0.002797, dice: 0.050167, loss: 0.026482\n",
      "val: bce: 0.003610, dice: 0.059508, loss: 0.031559\n",
      "saving best model\n",
      "0m 43s\n",
      "Epoch 17/39\n",
      "----------\n",
      "LR 0.0001\n",
      "train: bce: 0.002720, dice: 0.049958, loss: 0.026339\n",
      "val: bce: 0.003184, dice: 0.057431, loss: 0.030307\n",
      "saving best model\n",
      "0m 43s\n",
      "Epoch 18/39\n",
      "----------\n",
      "LR 0.0001\n",
      "train: bce: 0.002537, dice: 0.046737, loss: 0.024637\n",
      "val: bce: 0.003113, dice: 0.054996, loss: 0.029055\n",
      "saving best model\n",
      "0m 43s\n",
      "Epoch 19/39\n",
      "----------\n",
      "LR 0.0001\n",
      "train: bce: 0.002300, dice: 0.044468, loss: 0.023384\n",
      "val: bce: 0.002945, dice: 0.051255, loss: 0.027100\n",
      "saving best model\n",
      "0m 43s\n",
      "Epoch 20/39\n",
      "----------\n",
      "LR 0.0001\n",
      "train: bce: 0.002042, dice: 0.040555, loss: 0.021299\n",
      "val: bce: 0.002866, dice: 0.050504, loss: 0.026685\n",
      "saving best model\n",
      "0m 43s\n",
      "Epoch 21/39\n",
      "----------\n",
      "LR 0.0001\n",
      "train: bce: 0.001988, dice: 0.038980, loss: 0.020484\n",
      "val: bce: 0.002593, dice: 0.047394, loss: 0.024993\n",
      "saving best model\n",
      "0m 43s\n",
      "Epoch 22/39\n",
      "----------\n",
      "LR 0.0001\n",
      "train: bce: 0.001841, dice: 0.036638, loss: 0.019239\n",
      "val: bce: 0.002522, dice: 0.045939, loss: 0.024230\n",
      "saving best model\n",
      "0m 43s\n",
      "Epoch 23/39\n",
      "----------\n",
      "LR 0.0001\n",
      "train: bce: 0.001795, dice: 0.035693, loss: 0.018744\n",
      "val: bce: 0.002727, dice: 0.044743, loss: 0.023735\n",
      "saving best model\n",
      "0m 43s\n",
      "Epoch 24/39\n",
      "----------\n",
      "LR 0.0001\n",
      "train: bce: 0.001691, dice: 0.034025, loss: 0.017858\n",
      "val: bce: 0.002360, dice: 0.043020, loss: 0.022690\n",
      "saving best model\n",
      "0m 43s\n",
      "Epoch 25/39\n",
      "----------\n",
      "LR 1e-05\n",
      "train: bce: 0.001572, dice: 0.031303, loss: 0.016437\n",
      "val: bce: 0.002217, dice: 0.040832, loss: 0.021524\n",
      "saving best model\n",
      "0m 43s\n",
      "Epoch 26/39\n",
      "----------\n",
      "LR 1e-05\n",
      "train: bce: 0.001514, dice: 0.030473, loss: 0.015993\n",
      "val: bce: 0.002166, dice: 0.040488, loss: 0.021327\n",
      "saving best model\n",
      "0m 43s\n",
      "Epoch 27/39\n",
      "----------\n",
      "LR 1e-05\n",
      "train: bce: 0.001501, dice: 0.030128, loss: 0.015815\n",
      "val: bce: 0.002229, dice: 0.040340, loss: 0.021285\n",
      "saving best model\n",
      "0m 43s\n",
      "Epoch 28/39\n",
      "----------\n",
      "LR 1e-05\n",
      "train: bce: 0.001496, dice: 0.029890, loss: 0.015693\n",
      "val: bce: 0.002166, dice: 0.040157, loss: 0.021162\n",
      "saving best model\n",
      "0m 43s\n",
      "Epoch 29/39\n",
      "----------\n",
      "LR 1e-05\n",
      "train: bce: 0.001488, dice: 0.029740, loss: 0.015614\n",
      "val: bce: 0.002215, dice: 0.040059, loss: 0.021137\n",
      "saving best model\n",
      "0m 43s\n",
      "Epoch 30/39\n",
      "----------\n",
      "LR 1e-05\n",
      "train: bce: 0.001479, dice: 0.029537, loss: 0.015508\n",
      "val: bce: 0.002149, dice: 0.039748, loss: 0.020948\n",
      "saving best model\n",
      "0m 43s\n",
      "Epoch 31/39\n",
      "----------\n",
      "LR 1e-05\n",
      "train: bce: 0.001469, dice: 0.029364, loss: 0.015416\n",
      "val: bce: 0.002212, dice: 0.039819, loss: 0.021016\n",
      "0m 43s\n",
      "Epoch 32/39\n",
      "----------\n",
      "LR 1e-05\n",
      "train: bce: 0.001470, dice: 0.029170, loss: 0.015320\n",
      "val: bce: 0.002146, dice: 0.039689, loss: 0.020918\n",
      "saving best model\n",
      "0m 43s\n",
      "Epoch 33/39\n",
      "----------\n",
      "LR 1e-05\n",
      "train: bce: 0.001456, dice: 0.029055, loss: 0.015255\n",
      "val: bce: 0.002180, dice: 0.039492, loss: 0.020836\n",
      "saving best model\n",
      "0m 43s\n",
      "Epoch 34/39\n",
      "----------\n",
      "LR 1e-05\n",
      "train: bce: 0.001451, dice: 0.028900, loss: 0.015175\n",
      "val: bce: 0.002170, dice: 0.039412, loss: 0.020791\n",
      "saving best model\n",
      "0m 43s\n",
      "Epoch 35/39\n",
      "----------\n",
      "LR 1e-05\n",
      "train: bce: 0.001432, dice: 0.028700, loss: 0.015066\n",
      "val: bce: 0.002203, dice: 0.039768, loss: 0.020985\n",
      "0m 43s\n",
      "Epoch 36/39\n",
      "----------\n",
      "LR 1e-05\n",
      "train: bce: 0.001433, dice: 0.028581, loss: 0.015007\n",
      "val: bce: 0.002091, dice: 0.039245, loss: 0.020668\n",
      "saving best model\n",
      "0m 43s\n",
      "Epoch 37/39\n",
      "----------\n",
      "LR 1e-05\n",
      "train: bce: 0.001422, dice: 0.028358, loss: 0.014890\n",
      "val: bce: 0.002160, dice: 0.039272, loss: 0.020716\n",
      "0m 43s\n",
      "Epoch 38/39\n",
      "----------\n",
      "LR 1e-05\n",
      "train: bce: 0.001414, dice: 0.028230, loss: 0.014822\n",
      "val: bce: 0.002143, dice: 0.039213, loss: 0.020678\n",
      "0m 43s\n",
      "Epoch 39/39\n",
      "----------\n",
      "LR 1e-05\n",
      "train: bce: 0.001406, dice: 0.027994, loss: 0.014700\n",
      "val: bce: 0.002083, dice: 0.039034, loss: 0.020559\n",
      "saving best model\n",
      "0m 43s\n",
      "Best val loss: 0.020559\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import torch.optim as optim\n",
    "from torch.optim import lr_scheduler\n",
    "import time\n",
    "import copy\n",
    "\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "print(device)\n",
    "\n",
    "num_class = 6\n",
    "\n",
    "model = pytorch_unet.UNet(num_class).to(device)\n",
    "\n",
    "# Observe that all parameters are being optimized\n",
    "optimizer_ft = optim.Adam(model.parameters(), lr=1e-4)\n",
    "\n",
    "exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=25, gamma=0.1)\n",
    "\n",
    "model = train_model(model, optimizer_ft, exp_lr_scheduler, num_epochs=40)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(3, 6, 192, 192)\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAsQAAAKvCAYAAABtZtkaAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAIABJREFUeJzt3X+sbGd5H/rvUztwJUIFxPtavv5RG8shCml6AltOEIELBRJDURzCFbVVJU6DekAFqU1y1ZJQFW6vkKI2BCnKjclBWDZXiSGtQ7Fy3QYX0UCQKRwT1zEEg+0YcY6MvcERWElEYvu5f5zZYTjsHzN7ZvbM7PX5SFt75p21Zp61fR6/3/Puddaq7g4AAAzV31l2AQAAsEwCMQAAgyYQAwAwaAIxAACDJhADADBoAjEAAIO2sEBcVVdV1b1VdV9VvWVRnwMAALOoRVyHuKrOSfKFJK9IcirJp5Nc292fm/uHAQDADBa1Qnxlkvu6+4Hu/usk709y9YI+CwAADuzcBb3vhUm+PPb8VJIf3m3jqnK7PIbsq929sewipnHeeef1pZdeuuwyYCnuvPPOtepZ/cqQTdqviwrE+6qq40mOL+vzYYV8adkFTGK8Zy+55JKcPHlyyRXBclTVyvesfoUzJu3XRZ0ycTrJxWPPLxqN/a3uPtHdm929uaAagDka79mNjbVZHINB0q8wnUUF4k8nuaKqLquqpyS5JsmtC/osAAA4sIWcMtHdj1fVm5P8QZJzktzQ3Z9dxGcBAMAsFnYOcXffluS2Rb0/AADMgzvVAQAwaAIxAACDJhADADBoAjEAAIMmEAMAMGgCMQAAgyYQAwAwaAIxAACDJhADADBoAjEAAIMmEAMAMGgCMQAAgyYQAwAwaAIxAACDJhADADBoAjEAAIN24EBcVRdX1Uer6nNV9dmq+hej8bdX1emqumv09ar5lQsAAPN17gz7Pp7kF7v7M1X19CR3VtXto9fe1d2/Ont5AACwWAcOxN39UJKHRo8fq6o/TXLhvAoDAIDDMJdziKvq0iQ/lOR/jIbeXFV3V9UNVfXMeXwGAAAswsyBuKq+O8ktSf5ld38jyfVJLk9yLGdWkN+5y37Hq+pkVZ2ctQZg8cZ7dmtra9nlAHvQrzCdmQJxVX1XzoTh3+7u30uS7n64u5/o7ieTvCfJlTvt290nunuzuzdnqQE4HOM9u7GxsexygD3oV5jOLFeZqCTvTfKn3f1rY+MXjG32miT3HLw8AABYrFmuMvHCJD+d5E+q6q7R2C8nubaqjiXpJA8mecNMFQIAwALNcpWJP0pSO7x028HLAQCAw+VOdQAADJpADADAoAnEAAAMmkAMAMCgCcQAAAyaQAwAwKAJxAAADJpADADAoAnEAAAMmkAMAMCgCcQAAAyaQAwAwKAJxAAADJpADADAoAnEAAAMmkAMAMCgnTvrG1TVg0keS/JEkse7e7OqnpXkA0kuTfJgktd195/P+lkAADBv81ohfml3H+vuzdHztyT5SHdfkeQjo+cAALByFnXKxNVJbho9vinJTy7ocwAAYCbzCMSd5MNVdWdVHR+Nnd/dD40efyXJ+XP4HAAAmLuZzyFO8qPdfbqq/tckt1fV58df7O6uqj57p1F4Pn72OLCaxnv2kksuWXI1wF70K0xn5hXi7j49+v5Ikg8muTLJw1V1QZKMvj+yw34nuntz7LxjYIWN9+zGxsayywH2oF9hOjMF4qp6WlU9fftxkh9Lck+SW5NcN9rsuiQfmuVzAABgUWY9ZeL8JB+squ33+p3u/q9V9ekkv1tVr0/ypSSvm/FzAABgIWYKxN39QJJ/sMP415K8bJb3BgCAw+BOdQAADJpADADAoAnEAAAMmkAMAMCgCcQAAAyaQAwAwKAJxAAADJpADADAoAnEAAAMmkAMAMCgCcQAAAyaQAwAwKAJxAAADJpADADAoAnEAAAMmkAM
AMCgnXvQHavqOUk+MDb07CT/NskzkvyzJFuj8V/u7tsOXCEAACzQgQNxd9+b5FiSVNU5SU4n+WCSf5rkXd39q3OpEAAAFmhep0y8LMn93f2lOb0fAAAcinkF4muS3Dz2/M1VdXdV3VBVz5zTZwAAwNzNHIir6ilJfiLJfxwNXZ/k8pw5neKhJO/cZb/jVXWyqk7OWgOweOM9u7W1tf8OwNLoV5jOPFaIX5nkM939cJJ098Pd/UR3P5nkPUmu3Gmn7j7R3ZvdvTmHGoAFG+/ZjY2NZZcD7EG/wnTmEYivzdjpElV1wdhrr0lyzxw+AwAAFuLAV5lIkqp6WpJXJHnD2PC/r6pjSTrJg2e9tjDdvetrVXUYJQBTeOyu3f8+/vRjTx5iJcB+9CtH3UyBuLv/Isn3nDX20zNVdLA69n1dKIbVsdfkuv26SRZWg35lCNb+TnX7heFptwMWa7/JddrtgMXRrwzFWv8JnjbkCsWwXNNOmiZZWB79ypD40wsAwKAJxAAADJpADADAoAnEAAAMmkAMAMCgCcQAAAzaWgfiaW+24eYcsFzTXrzfxf5hefQrQ7LWgTiZPOQKw7AaJp00Ta6wfPqVoVj7QJzsH3aFYVgt+02eJldYHfqVITh32QXMi9AL68UkCutDv3LUHYkVYgAAOCiBGACAQROIAQAYNIGYuejuZZcATKF+69XLLgGYkH5dvIkCcVXdUFWPVNU9Y2PPqqrbq+qLo+/PHI1XVf16Vd1XVXdX1fMWVTyrRSiG9WKShfWhXxdr0qtM3JjkN5K8b2zsLUk+0t2/UlVvGT3/10lemeSK0dcPJ7l+9P1ImiYEDuFKGN09iONkfR2/5fTE25547YULrGQ11G+9Ov2G3192GbAj/frt9OviTLRC3N0fS/LoWcNXJ7lp9PimJD85Nv6+PuOTSZ5RVRfMo9hV0t1Tr4geZJ91NIRjZP0cv+X0VJPrQfdZR1aeWDX6dXf6dTFmOYf4/O5+aPT4K0nOHz2+MMmXx7Y7NRo7MmYNfEMIjEM4RtbHrJOkSRYOj37dn36dv7n8o7o+k36mSkBVdbyqTlbVyXnUcFjmFfSGEBiHcIxDMt6zW1tbyy5nYvOaHE2yrBP9ql+Zzix3qnu4qi7o7odGp0Q8Mho/neTise0uGo19m+4+keREklTVWiSnvQLeXufN7rbfEM63HcIxDsV4z25ubq5Fz+41Ke51vuFu+x2/5fRKn6f4gsvfMfG2F73woh3HK85RPAr06+r367bxvr3j/rdOta9ziudnlhXiW5NcN3p8XZIPjY3/zOhqEz+S5Otjp1asrd1CbVXtG/j22mbVV1G369vv+yTvAYdpt0nyxGsv3HeS3GsbK08wf0Pt17P/EvuCy9/xbV+T0K/zMell125OckeS51TVqap6fZJfSfKKqvpikpePnifJbUkeSHJfkvck+edzr/qQ7RWGp7GOoXi75v2+72eVj5GjZ6/JdRrrOsnOg0mWwzLUfp0k8G5vs98qsH6d3aRXmbi2uy/o7u/q7ou6+73d/bXufll3X9HdL+/uR0fbdne/qbsv7+6/391rdY7wpA56GsC6nT4wjxXis98LluGgvzpdh1+5LopJlmXRr99ipfhwuFPdPnYKcbOG2p32X9WwOK8V4m2repwcHTutBs06Se60/6quOs2bSZZF0q+TecHl75joXGH9enACMXua5wrx2e8JrAeTLCzfdih2+sRizHKViUGa1ykPVbUWwXDSFeJ1OxWE4ZjXr1BPvPbCtV9lGnfRJy7ecfzL73v3IVcC36Jf9/a35xTf78oS82aFmD1NukI8j5VjAGB/01xikckIxOxp2hViK8YAwLoRiNmTFWIA4KgTiNmTFWIA4KgTiNmTFWIAWD3OI54vgZg9WSEGAI46gXhK81r5XJcVVCvErLt5XXrpKF7CCVaNfmVZBGL2ZIUYAObvjvvfuuwSGCMQ72MRt1lexO2gF8UKMetmEbdtXcTtZQH9yuoQ
iA/ooIFv3YKiFWKOioNOsn71CodvKP16x/1vPdBK8UH3Y3e1CgGtqpZfxD52+zlNE/zm8R4cSXd29+ayi5jG5uZmnzx5ctll7Gm3iXGalaJ5vAdHT1WtVc/qV/06ZJP2qxXiCe0WWrt731XfvbYRhmExdpsEj99yet9VpL22MbnC/OlXlu3cZRewTqpq12B7kJV2YRgW68RrL9x1ojzIr1ZNrrA4+pVl2neFuKpuqKpHquqesbH/UFWfr6q7q+qDVfWM0filVfVXVXXX6Ovdiyx+GeYVYoVhOBzzmhRNrrB4+pVlmeSUiRuTXHXW2O1JfqC7fzDJF5L80thr93f3sdHXG+dT5mqZNcwKw3C4Zp0cTa5wePQry7DvKRPd/bGquvSssQ+PPf1kkv9jvmWtvu1QO82pEoIwLM/2JDnNr15NrLAc+pXDNo9ziH8uyQfGnl9WVX+c5BtJ/k13f3wOn7GyhFxYLyZNWB/6lcMyUyCuqrcmeTzJb4+GHkpySXd/raqen+Q/V9Vzu/sbO+x7PMnxWT4fODzjPXvJJZcsuRpgL/oVpnPgy65V1c8meXWSf9Kj8wa6+5vd/bXR4zuT3J/ke3fav7tPdPfmOl3LEYZsvGc3NjaWXQ6wB/0K0zlQIK6qq5L8qyQ/0d1/OTa+UVXnjB4/O8kVSR6YR6EAALAI+54yUVU3J3lJkvOq6lSSt+XMVSWemuT20Tm0nxxdUeLFSf5dVf1NkieTvLG7H11Q7QAAMLNJrjJx7Q7D791l21uS3DJrUQAAcFjcuhkAgEETiAEAGDSBGACAQROIAQAYNIEYAIBBE4gBABg0gRgAgEETiAEAGDSBGACAQROIAQAYNIEYAIBBE4gBABg0gRgAgEETiAEAGDSBGACAQROIAQAYtH0DcVXdUFWPVNU9Y2Nvr6rTVXXX6OtVY6/9UlXdV1X3VtWPL6pwAACYh0lWiG9MctUO4+/q7mOjr9uSpKq+P8k1SZ472uc3q+qceRULAADztm8g7u6PJXl0wve7Osn7u/ub3f1nSe5LcuUM9QEAwELNcg7xm6vq7tEpFc8cjV2Y5Mtj25wajQEAwEo694D7XZ/k/07So+/vTPJz07xBVR1PcvyAn88R0N1T71NVC6iESYz37CWXXLLkaliGT73wpVPvc+UnPrqAStiPfkW/TudAK8Td/XB3P9HdTyZ5T751WsTpJBePbXrRaGyn9zjR3ZvdvXmQGoDDNd6zGxsbyy4H2IN+hekcKBBX1QVjT1+TZPsKFLcmuaaqnlpVlyW5IsmnZisRAAAWZ99TJqrq5iQvSXJeVZ1K8rYkL6mqYzlzysSDSd6QJN392ar63SSfS/J4kjd19xOLKZ1Vtn06hFMcYD3c9PEXJEmue9EdS64E2I9+nb99A3F3X7vD8Hv32P4dSd4xS1EAAHBY3KkOAIBBE4gBABg0gRgAgEETiAEAGDSBGACAQROIAQAYNIEYAIBB2/c6xHC27ZtuzHtbYDG2L+I/67bnvfzMlLHx30wdsCjz6NfrXnRHPvOCn8rz7vi9eZV15Pm/GkvnbnawHr768seTJP/o//r4kisB9nLTx1+Q593xv/zt8ys/8dElVrMeBGKmNkmAdetmWB2T3N7VrWBhNcytXz8xr4qGwTnEAAAMmkAMAMCgCcQAAAyaQAwAwKAJxAAADJpADADAoO0biKvqhqp6pKruGRv7QFXdNfp6sKruGo1fWlV/NfbauxdZPAAAzGqS6xDfmOQ3krxve6C7//H246p6Z5Kvj21/f3cfm1eBAACwSPsG4u7+WFVdutNrdeauC69L8g/nWxbrzg05YL24IQesD/06f7OeQ/yiJA939xfHxi6rqj+uqj+sqhfN+P4AALBQs966+dokN489fyjJJd39tap6fpL/XFXP7e5vnL1jVR1PcnzGzwcOyXjPXnLJJUuuBtiLfoXpHHiFuKrOTfJTST6wPdbd3+zur40e35nk/iTfu9P+3X2iuze7e/Og
NQCHZ7xnNzY2ll0OsAf9CtOZ5ZSJlyf5fHef2h6oqo2qOmf0+NlJrkjywGwlAgDA4kxy2bWbk9yR5DlVdaqqXj966Zp8++kSSfLiJHePLsP2n5K8sbsfnWfBAAAwT5NcZeLaXcZ/doexW5LcMntZAABwONypDgCAQROIAQAYNIEYAIBBE4gBABg0gRgAgEETiAEAGDSBGACAQROIAQAYNIEYAIBBE4gBABg0gRgAgEGr7l52DamqrSR/keSry65lDs6L41gl63Acf6+7N5ZdxDSq6rEk9y67jjlYhz8fk3Ach2utela/rhzHcbgm6tdzD6OS/XT3RlWd7O7NZdcyK8exWo7Kcayge4/Cz/Wo/PlwHOxDv64Qx7GanDIBAMCgCcQAAAzaKgXiE8suYE4cx2o5Ksexao7Kz9VxrJajchyr5qj8XB3Hajkqx5FkRf5RHQAALMsqrRADAMChE4gBABg0gRgAgEETiAEAGDSBGACAQROIAQAYNIEYAIBBE4gBABg0gRgAgEETiAEAGDSBGACAQROIAQAYNIEYAIBBE4gBABg0gRgAgEETiAEAGDSBGACAQROIAQAYNIEYAIBBE4gBABg0gRgAgEETiAEAGDSBGACAQROIAQAYNIEYAIBBE4gBABg0gRgAgEETiAEAGDSBGACAQVtYIK6qq6rq3qq6r6resqjPAQCAWVR3z/9Nq85J8oUkr0hyKsmnk1zb3Z+b+4cBAMAMFrVCfGWS+7r7ge7+6yTvT3L1gj4LAAAObFGB+MIkXx57fmo0BgAAK+XcZX1wVR1Pcnz09PnLqgNWwFe7e2PZRexnvGef9rSnPf/7vu/7llwRLMedd9658j2rX+GMSft1UYH4dJKLx55fNBr7W919IsmJJKmq+Z/IDOvjS8suYBLjPbu5udknT55cckWwHFW18j2rX+GMSft1UadMfDrJFVV1WVU9Jck1SW5d0GcBAMCBLWSFuLsfr6o3J/mDJOckuaG7P7uIzwIAgFks7Bzi7r4tyW2Len8AAJgHd6oDAGDQBGIAAAZNIAYAYNAEYgAABk0gBgBg0ARiAAAGTSAGAGDQBGIAAAZNIAYAYNAEYgAABk0gBgBg0ARiAAAGTSAGAGDQBGIAAAZNIAYAYNAEYgAABu3AgbiqLq6qj1bV56rqs1X1L0bjb6+q01V11+jrVfMrFwAA5uvcGfZ9PMkvdvdnqurpSe6sqttHr72ru3919vIAAGCxDhyIu/uhJA+NHj9WVX+a5MJ5FQYAAIdhLucQV9WlSX4oyf8YDb25qu6uqhuq6pnz+AwAAFiEmQNxVX13kluS/Mvu/kaS65NcnuRYzqwgv3OX/Y5X1cmqOjlrDcDijffs1tbWsssB9qBfYTozBeKq+q6cCcO/3d2/lyTd/XB3P9HdTyZ5T5Ird9q3u09092Z3b85SA3A4xnt2Y2Nj2eUAe9CvMJ1ZrjJRSd6b5E+7+9fGxi8Y2+w1Se45eHkAALBYs1xl4oVJfjrJn1TVXaOxX05ybVUdS9JJHkzyhpkqBACABZrlKhN/lKR2eOm2g5cDAACHy53qAAAYNIEYAIBBE4gBABg0gRgAgEETiAEAGDSBGACAQROIAQAYNIEYAIBBE4gBABg0gRgAgEETiAEAGDSBGACAQROIAQAYNIEYAIBBE4gBABg0gRgAgEE7d9Y3qKoHkzyW5Ikkj3f3ZlU9K8kHklya5MEkr+vuP5/1swAAYN7mtUL80u4+1t2bo+dvSfKR7r4iyUdGzwEAYOUs6pSJq5PcNHp8U5KfXNDnHEh3L7sEYAr1W69edgnAhPQr62gegbiTfLiq7qyq46Ox87v7odHjryQ5fw6fM1dCMawXkyysD/3KuplHIP7R7n5eklcmeVNVvXj8xT6TPL8jfVbV8ao6WVUn51DDgQjFMLnxnt3a2lpODSZZmIh+henMHIi7+/To+yNJPpjkyiQPV9UFSTL6/sgO+53o7s2x846XQiiGyYz37MbGxtLq
MMnC/vQrTGemQFxVT6uqp28/TvJjSe5JcmuS60abXZfkQ7N8zqIJxbBeTLKwPvQr62DWFeLzk/xRVf3PJJ9K8v91939N8itJXlFVX0zy8tHzlSYUw3oxycL60K+supkCcXc/0N3/YPT13O5+x2j8a939su6+ortf3t2PzqfcxRKKYb2YZGF96FdWmTvVnUUohvVikoX1oV9ZVQLxDoRiWC8mWVgf+pVVJBDvQiiG9WKShfWhX1k1AvEehGJYLyZZWB/6lVUiEO9DKIb1YpKF9aFfWRUC8QSEYlgvJllYH/qVVSAQT0gohvVikoX1oV9ZNoF4CkIxrBeTLKwP/coyCcRTEophvZhkYX3oV5ZFID4AoRjWi0kW1od+ZRkE4gMSimG9mGRhfehXDptAPAOhGNaLSRbWh37lMAnEMxKKYb2YZGF96FcOy7nLLmAZqmrZJQBT6Df8/rJLACakX1lHVogBABg0gRgAgEE78CkTVfWcJB8YG3p2kn+b5BlJ/lmSrdH4L3f3bQeuEAAAFujAgbi7701yLEmq6pwkp5N8MMk/TfKu7v7VuVQIAAALNK9TJl6W5P7u/tKc3g8AAA7FvALxNUluHnv+5qq6u6puqKpnzukzAABg7mYOxFX1lCQ/keQ/joauT3J5zpxO8VCSd+6y3/GqOllVJ2etAVi88Z7d2trafwdgafQrTGceK8SvTPKZ7n44Sbr74e5+orufTPKeJFfutFN3n+juze7enEMNwIKN9+zGxsayywH2oF9hOvMIxNdm7HSJqrpg7LXXJLlnDp8BAAALMdOd6qrqaUlekeQNY8P/vqqOJekkD571GgAArJSZAnF3/0WS7zlr7KdnqggAAA6RO9UBADBoAjEAAIMmEAMAMGgCMQAAgyYQAwAwaAIxAACDJhADADBoAjEAAIMmEAMAMGgCMQAAgyYQAwAwaAIxAACDJhADADBoAjEAAIMmEAMAMGgCMQAAgzZRIK6qG6rqkaq6Z2zsWVV1e1V9cfT9maPxqqpfr6r7quruqnreoooHAIBZTbpCfGOSq84ae0uSj3T3FUk+MnqeJK9McsXo63iS62cvEwAAFmOiQNzdH0vy6FnDVye5afT4piQ/OTb+vj7jk0meUVUXzKNYAACYt1nOIT6/ux8aPf5KkvNHjy9M8uWx7U6NxgAAYOXM5R/VdXcn6Wn2qarjVXWyqk7OowZgscZ7dmtra9nlAHvQrzCdWQLxw9unQoy+PzIaP53k4rHtLhqNfZvuPtHdm929OUMNwCEZ79mNjY1llwPsQb/CdGYJxLcmuW70+LokHxob/5nR1SZ+JMnXx06tAACAlXLuJBtV1c1JXpLkvKo6leRtSX4lye9W1euTfCnJ60ab35bkVUnuS/KXSf7pnGsGAIC5mSgQd/e1u7z0sh227SRvmqUoAAA4LO5UBwDAoAnEAAAMmkAMAMCgCcQAAAyaQAwAwKAJxAAADJpADADAoAnEAAAMmkAMAMCgCcQAAAyaQAwAwKAJxAAADJpADADAoAnEAAAM2rnLLgDmqbtTVXP7DizWCy5/x9ze64773zq39wK+01HuVyvEHCnbIXZe3wGAo2/fQFxVN1TVI1V1z9jYf6iqz1fV3VX1wap6xmj80qr6q6q6a/T17kUWD2fr7rl+BwCOvklWiG9MctVZY7cn+YHu/sEkX0jyS2Ov3d/dx0Zfb5xPmTAZK8QAwLT2DcTd/bEkj5419uHufnz09JNJLlpAbTA1K8QAwLTmcQ7xzyX5L2PPL6uqP66qP6yqF83h/WFiVogBgGnNFIir6q1JHk/y26Ohh5Jc0t0/lOQXkvxOVf3dXfY9XlUnq+rkLDXAOCvEizPes1tbW8suB9iDfoXpHDgQV9XPJnl1kn/So/TQ3d/s7q+NHt+Z5P4k37vT/t19ors3u3vzoDXA2awQL854z25sbCy7HGAP+hWmc6BAXFVXJflXSX6iu/9ybHyjqs4ZPX52kiuSPDCPQmESVogB
gGnte2OOqro5yUuSnFdVp5K8LWeuKvHUJLePVtI+ObqixIuT/Luq+pskTyZ5Y3c/uuMbwwJYIQYAprVvIO7ua3cYfu8u296S5JZZi4KDcqc6AGBa7lTHkWKFGACYlkDMkeIcYgBgWgIxR4oVYgBgWgIxR4oVYgBgWgIxR4oVYgBgWgIxR4oVYgBgWgIxR4oVYgBgWgIxR4oVYgBgWvvemAPWiRViWC933P/WZZcATOgo96sVYgAABk0gBgBg0ARiAAAGTSAGAGDQBGIAAAZNIAYAYNAEYgAABm3fQFxVN1TVI1V1z9jY26vqdFXdNfp61dhrv1RV91XVvVX144sqHAAA5mGSFeIbk1y1w/i7uvvY6Ou2JKmq709yTZLnjvb5zao6Z17FcrR0tzvCwRq56eMvyE0ff8GyywAmoF+ns++d6rr7Y1V16YTvd3WS93f3N5P8WVXdl+TKJHccuMKBmiYouqsaLN/xW05PvO2J1164wEqA/ehXzjbLrZvfXFU/k+Rkkl/s7j9PcmGST45tc2o0xoQOsmK6vY9gDIdvmon17H1MtHC49Cu7Oeg/qrs+yeVJjiV5KMk7p32DqjpeVSer6uQBazhyZj19wOkHLNJ4z25tbS27nJVwkMl1nvvDbvTrd9Kv7OVAK8Td/fD246p6T5LfHz09neTisU0vGo3t9B4nkpwYvcegk9w8g6zVYhZlvGc3NzcH3bPznBitPrEI+vVb9CuTOFAgrqoLuvuh0dPXJNm+AsWtSX6nqn4tyf+W5Iokn5q5yiNsvzC8V7Dda9/uFophAfabXPeaKPfa9/gtp02yMGf6lUntG4ir6uYkL0lyXlWdSvK2JC+pqmNJOsmDSd6QJN392ar63SSfS/J4kjd19xOLKf1omyTMbm/jVAlYvkkmx+1t/OoVlku/crZJrjJx7Q7D791j+3ckeccsRR2WZV/JYbfPn/azqmrH97JKzFHzqRe+dOJtr/zER+f++btNjNOuFJ147YU7vpdVJ44S/co6cae6JZlXGN5vP6vHMB/zmlz3289qFMxOvzItgXiFzLqaazUYDtesq0NWl+Dw6Ff2IhAvwU6rtvMKszu9j1VimM1Oq0Dzmhx3eh+rTnBw+pWDEIgBABi0We5Ux5ytNTNaAAAZK0lEQVTM+1SH3f6R3aId9DOn3c+pISzbvH91uts/2lm0mz7+gkPZ77oX3XGgz4F50K/6dRIC8YRmuV4wcLjOe9WDeeyuvX8B9vRjTx5SNcB+xq9IceUnPvod/atfWTSBeE5c4mz6vxS4qx6LcN6rHpxou8fu+juDn2SnXQnaXmka6goSi7Xduw+847Js/KNvf02/6tdFE4jnSMCD5Zk0CI/bXoUa+kQLyzTeuxv/6Eu7bqdfWSSBGABYiu0wvFcQhsPgKhML4DJncLgOsjo8br/zjYHp7Xf3uYOGYf3KIgx6hXiR1+x1TjHM304T7LwmR+cowvyd3bPb/3hu1pVh/cq8+WvWCpj3irIValiseV9yyYX9GZLDPk1CvzIJgRgAOBTbK8bOGWbVCMRLsOhTNSb5PGByi7xd6yJvMwuraNFhWL9yEALxCpk1FDtVAg7XrJOsX73C4dGv7EUgXpLdVm3nfftjq8MwH7utAh10ktxtP6tNMDv9yrT2DcRVdUNVPVJV94yNfaCq7hp9PVhVd43GL62qvxp77d2LLH7dzSsUC8NwOOY1yZpcYfH0K9OY5LJrNyb5jSTv2x7o7n+8/biq3pnk62Pb39/dx+ZV4FBNctc7p0jA6tieNPeaJP3KFVaDfuVs+wbi7v5YVV2602t1Jq29Lsk/nG9Zw1FVewbbg4bedVgdXoca4WwnXnvhnhPlQSfRdVhtuu5Fdyy7BJiKfmVSs96Y40VJHu7uL46NXVZVf5zkG0n+TXd/fMbPOFT7BdRp3mfabQ/7c+EoePqxJ+dyc45pLvK/PRnOYwVpHSZWWDX6lXmbdRa5NsnNY88fSnJJd/9Qkl9I8jtV
9Xd32rGqjlfVyao6OWMNczdrqDzo/sv6XJjEeM9ubW0tu5xvM+sdqw66/6yTo8mVRTnK/XpQ+pW91CSrkqNTJn6/u39gbOzcJKeTPL+7T+2y339P8n92956ht6pW7mTYWVZr5xFMp/l8QXjt3dndm8suYhqbm5t98uRq/V12llXieUzQ06w+mVjXW1WtVc+uYr8mB+9Z/co0Ju3XWU6ZeHmSz4+H4araSPJodz9RVc9OckWSB2b4jKU5yGkM8wymQi5MZ3uSnGaSnedKlUkTpjPt6U76lUWa5LJrNye5I8lzqupUVb1+9NI1+fbTJZLkxUnuHl2G7T8leWN3PzrPgg/bpMFUgIXVMOmkuaxf2wLfol9ZFZNcZeLaXcZ/doexW5LcMntZq0XYhfVi8oT1oV9ZBe5UBwDAoAnEAAAMmkAMAMCgCcQAAAyaQAwAwKAJxAAADJpADADAoAnEAAAMmkAMAMCgCcQAAAyaQAwAwKAJxAAADFp197JrSFVtJfmLJF9ddi1zcF4cxypZh+P4e929sewiplFVjyW5d9l1zME6/PmYhOM4XGvVs/p15TiOwzVRv557GJXsp7s3qupkd28uu5ZZOY7VclSOYwXdexR+rkflz4fjYB/6dYU4jtXklAkAAAZNIAYAYNBWKRCfWHYBc+I4VstROY5Vc1R+ro5jtRyV41g1R+Xn6jhWy1E5jiQr8o/qAABgWVZphRgAAA6dQAwAwKAJxAAADJpADADAoAnEAAAMmkAMAMCgCcQAAAyaQAwAwKAJxAAADJpADADAoAnEAAAMmkAMAMCgCcQAAAyaQAwAwKAJxAAADJpADADAoAnEAAAMmkAMAMCgCcQAAAyaQAwAwKAJxAAADJpADADAoAnEAAAMmkAMAMCgCcQAAAyaQAwAwKAJxAAADJpADADAoAnEAAAM2sICcVVdVVX3VtV9VfWWRX0OAADMorp7/m9adU6SLyR5RZJTST6d5Nru/tzcPwwAAGawqBXiK5Pc190PdPdfJ3l/kqsX9FkAAHBg5y7ofS9M8uWx56eS/PD4BlV1PMnx0dPnL6gOWAdf7e6NZRexn/GefdrTnvb87/u+71tyRbAcd95558r3rH6FMybt10UF4n1194kkJ5KkquZ/3gasjy8tu4BJjPfs5uZmnzx5cskVwXJU1cr3rH6FMybt10WdMnE6ycVjzy8ajQEAwEpZVCD+dJIrquqyqnpKkmuS3LqgzwIAgANbyCkT3f14Vb05yR8kOSfJDd392UV8FgAAzGJh5xB3921JblvU+wMAwDy4Ux0AAIMmEAMAMGgCMQAAgyYQAwAwaAIxAACDJhADADBoAjEAAIMmEAMAMGgCMQAAgyYQAwAwaAIxAACDJhADADBoAjEAAIMmEAMAMGgCMQAAgyYQAwAwaAcOxFV1cVV9tKo+V1Wfrap/MRp/e1Wdrqq7Rl+vml+5AAAwX+fOsO/jSX6xuz9TVU9PcmdV3T567V3d/auzlwcAAIt14EDc3Q8leWj0+LGq+tMkF86rMAAAOAxzOYe4qi5N8kNJ/sdo6M1VdXdV3VBVz9xln+NVdbKqTs6jBmCxxnt2a2tr2eUAe9CvMJ2ZA3FVfXeSW5L8y+7+RpLrk1ye5FjOrCC/c6f9uvtEd2929+asNQCLN96zGxsbyy4H2IN+henMFIir6rtyJgz/dnf/XpJ098Pd/UR3P5nkPUmunL1MAABYjFmuMlFJ3pvkT7v718bGLxjb7DVJ7jl4eQAAsFizXGXihUl+OsmfVNVdo7FfTnJtVR1L0kkeTPKGmSoEAIAFmuUqE3+UpHZ46baDlwMAAIfLneoAABg0gRgAgEETiAEAGDSBGACAQROIAQAYNIEYAIBBE4gBABg0gRgAgEETiAEAGDSBGACAQROIAQAYNIEYAIBBE4gBABg0gRgAgEETiAEAGDSBGACAQTt31jeoqgeTPJbkiSSPd/dmVT0ryQeSXJrkwSSv6+4/n/WzAABg3ua1
QvzS7j7W3Zuj529J8pHuviLJR0bPAQBg5SzqlImrk9w0enxTkp9c0OcAAMBM5hGIO8mHq+rOqjo+Gju/ux8aPf5KkvPP3qmqjlfVyao6OYcagAUb79mtra1llwPsQb/CdOYRiH+0u5+X5JVJ3lRVLx5/sbs7Z0Jzzho/0d2bY6dZACtsvGc3NjaWXQ6wB/0K05k5EHf36dH3R5J8MMmVSR6uqguSZPT9kVk/BwAAFmGmQFxVT6uqp28/TvJjSe5JcmuS60abXZfkQ7N8DgAALMqsl107P8kHq2r7vX6nu/9rVX06ye9W1euTfCnJ62b8HACAI+UFl79j6n3uuP+tC6iEmQJxdz+Q5B/sMP61JC+b5b1Zru5OVc38HTgcB5lYz2aihcMxS79u76tf58ud6tjRdpid9TsAwKoTiNnRmYuDzP4dAGDVCcTsyAoxAKyueZwmxbcIxOzICjEAMBQCMTuyQgwADIVAzI6sEAMAQyEQsyMrxADAUAjE7MgKMQAwFAIxO7JCDAAMhUDMjqwQAwBDIRCzIyvEAMBQCMTsyAoxADAUAjE7skIMAAyFQMyOrBADwOq64/63LruEI0UgZkdWiAGAoRCI2ZEVYgBgKM496I5V9ZwkHxgbenaSf5vkGUn+WZKt0fgvd/dtB66QpbBCDOvFr09hfejX1XPgFeLuvre7j3X3sSTPT/KXST44evld268dlTBsxRPWS/3Wq5ddAjAh/cqyzeuUiZclub+7vzSn91tJQjGsF5MsrA/9yjLNKxBfk+Tmsedvrqq7q+qGqnrmTjtU1fGqOllVJ+dUw6EQihmq8Z7d2traf4cVYZJliPQrTGfmQFxVT0nyE0n+42jo+iSXJzmW5KEk79xpv+4+0d2b3b05aw2HTShmiMZ7dmNjY9nlTMUky9DoV5jOPFaIX5nkM939cJJ098Pd/UR3P5nkPUmunMNnrByhGNaLSRbWh37lsM0jEF+bsdMlquqCsddek+SeOXzGShKKYb2YZGF96FcO00yBuKqeluQVSX5vbPjfV9WfVNXdSV6a5Odn+YxVJxTDejHJwvrQrxyWmQJxd/9Fd39Pd399bOynu/vvd/cPdvdPdPdDs5e52oRiWC8mWVgf+pXD4E51cyIUw3oxycL60K8smkA8R0IxrBeTLKwP/coiCcRzJhTDejHJwvrQryyKQLwAQjGsF5MsrA/9yiIIxAsiFMN6McnC+tCvzJtAvEBCMawXkyysD/3KPAnECyYUw3oxycL60K/Mi0B8CIRiWC8mWVgf+pV5EIgPiVAM68UkC+tDvzIrgfgQCcWwXkyysD70K7M4d9kFrIuqWnYJwBT6Db+/7BKACelXls0KMQAAgyYQAwAwaAIxAACDJhADADBoAjEAAIM2USCuqhuq6pGqumds7FlVdXtVfXH0/Zmj8aqqX6+q+6rq7qp63qKKBwCAWU162bUbk/xGkveNjb0lyUe6+1eq6i2j5/86ySuTXDH6+uEk14++swR7XfvYpeRg9Tx21+7rFE8/9uQhVgJMYree1a/rZaJA3N0fq6pLzxq+OslLRo9vSvLfcyYQX53kfX0miX2yqp5RVRd090PzKJjvdNAbfnS3UAxLcPyW0zuO/9gzfy3JzydJfvxZ7/qO1x+76++YZOGQnd2vJ157YZK9//K6/bp+XR+znEN8/ljI/UqS80ePL0zy5bHtTo3Gvk1VHa+qk1V1coYaBm/Wu9+5ex6TGu/Zra2tZZeztvYOw9/yB4/+/I7b7TcJQ6Jf52Wnfj1+y+mJ+1C/ro+5/JcarQZPlay6+0R3b3b35jxqGKKdwmxVWfVlIcZ7dmNjY9nlrKWdJtcTr70wJ1574Y4rwruFYtiPfp3dXv3K0TPLrZsf3j4VoqouSPLIaPx0kovHtrtoNMYcnR2GhWBYbbv92nXcdigeD8J/8OjP7xiWgcWZpF85WmZZIb41yXWjx9cl+dDY+M+MrjbxI0m+7vzh+RKGYb1MO7meHYCt
FMPhEYaHadLLrt2c5I4kz6mqU1X1+iS/kuQVVfXFJC8fPU+S25I8kOS+JO9J8s/nXjV/SxiG9TLp5GpVGJZPGB6OSa8yce0uL71sh207yZtmKQoAAA6Lf/64xqwOw3qZdrXJKjEsz1796nJqR49AfERNGpaFalgNk06wJmJYH/p1fQjER9h+YVcYhtWy3+RpcoXVoV+PFoF4jU1yU43t6xKfHX6FYTh8u92UY9zTjz35t1+/eP+X8+E//4V8+M9/weQKh2yaft1pnPUyy3WIAQAGTwBef1aI19ykt152i2ZYDZOsOk2zHbA40/Tr9hfrSSBeQ2ef7rBf2HUjD1ius/+1+n6TphsDwPLM2q+sJ4F4Te0Uis8OvjuNCcOwHDtNsmdPpDuNCcNw+PTr8DiHeI1V1Y4heK/tgeU58doLd5xU99oeWA79OiwC8ZrbKRTvtt062D6Ww653mnOs1+VnyWraaZLdbbt1cNPHX5Akue5Fdxzq537qhS+deNsrP/HRBVbCUaZf52Md+lUgPgIENFgv6zJ5Avp1KJxDDADAoAnEAAAMmkAMAMCgCcQAAAyaQAwAwKDtG4ir6oaqeqSq7hkb+w9V9fmquruqPlhVzxiNX1pVf1VVd42+3r3I4gEAYFaTrBDfmOSqs8ZuT/ID3f2DSb6Q5JfGXru/u4+Nvt44nzIBAGAx9g3E3f2xJI+eNfbh7n589PSTSS5aQG0AALBw8ziH+OeS/Jex55dV1R9X1R9W1Yt226mqjlfVyao6OYcagAUb79mtra1llwPsQb/CdGrC2/5emuT3u/sHzhp/a5LNJD/V3V1VT03y3d39tap6fpL/nOS53f2Nfd5/8vvmsnamuS3yQRyBO/Xd2d2byy5iGpubm33ypL/LHlXbt3ddlMO+bey8VdVa9ax+Pdr0694m7dcDrxBX1c8meXWSf9KjxNPd3+zur40e35nk/iTfe9DPAACARTv3IDtV1VVJ/lWS/727/3JsfCPJo939RFU9O8kVSR6YS6WsrWlWcLdXk4/Aqi+srWlWhLZXp9Z9FQnWlX6dj30DcVXdnOQlSc6rqlNJ3pYzV5V4apLbR8Hlk6MrSrw4yb+rqr9J8mSSN3b3ozu+MQAArIB9A3F3X7vD8Ht32faWJLfMWhQAABwWd6oDAGDQBGIAAAZNIAYAYNAEYgAABk0gBgBg0ARiAAAGTSAGAGDQBGIAAAbtQLduhkVxy2ZYL24BC+tDv+7OCjEAAIMmEAMAMGgCMQAAgyYQAwAwaAIxAACDJhADADBoAjEAAIO2byCuqhuq6pGqumds7O1Vdbqq7hp9vWrstV+qqvuq6t6q+vFFFQ4AAPMwyQrxjUmu2mH8Xd19bPR1W5JU1fcnuSbJc0f7/GZVnTOvYgEAYN72DcTd/bEkj074flcneX93f7O7/yzJfUmunKE+AABYqFnOIX5zVd09OqXimaOxC5N8eWybU6Ox71BVx6vqZFWdnKEG4JCM9+zW1tayywH2oF9hOgcNxNcnuTzJsSQPJXnntG/Q3Se6e7O7Nw9YA3CIxnt2Y2Nj2eUAe9CvMJ0DBeLufri7n+juJ5O8J986LeJ0kovHNr1oNAYAACvpQIG4qi4Ye/qaJNtXoLg1yTVV9dSquizJFUk+NVuJAACwOOfut0FV3ZzkJUnOq6pTSd6W5CVVdSxJJ3kwyRuSpLs/W1W/m+RzSR5P8qbufmIxpQMAwOz2DcTdfe0Ow+/dY/t3JHnHLEUBAMBhcac6AAAGTSAGAGDQBGIAAAZNIAYAYNAEYgAABk0gBgBg0ARiAAAGTSAGAGDQBGIAAAZNIAYAYNAEYgAABk0gBgBg0ARiAAAGTSAGAGDQBGIAAAZNIAYAYND2DcRVdUNVPVJV94yNfaCq7hp9PVhVd43GL62qvxp77d2LLB4AAGZ17gTb3JjkN5K8b3ugu//x9uOqemeSr49tf393H5tXgQAAsEj7BuLu/lhVXbrTa1VVSV6X5B/OtywA
ADgcs55D/KIkD3f3F8fGLquqP66qP6yqF+22Y1Udr6qTVXVyxhqAQzDes1tbW8suB9iDfoXpzBqIr01y89jzh5Jc0t0/lOQXkvxOVf3dnXbs7hPdvdndmzPWAByC8Z7d2NhYdjnAHvQrTOfAgbiqzk3yU0k+sD3W3d/s7q+NHt+Z5P4k3ztrkQAAsCizrBC/PMnnu/vU9kBVbVTVOaPHz05yRZIHZisRAAAWZ5LLrt2c5I4kz6mqU1X1+tFL1+TbT5dIkhcnuXt0Gbb/lOSN3f3oPAsGAIB5muQqE9fuMv6zO4zdkuSW2csCAIDD4U51AAAMmkAMAMCgCcQAAAyaQAwAwKAJxAAADJpADADAoAnEAAAMmkAMAMCgCcQAAAyaQAwAwKAJxAAADFp197JrSFVtJfmLJF9ddi1zcF4cxypZh+P4e929sewiplFVjyW5d9l1zME6/PmYhOM4XGvVs/p15TiOwzVRv557GJXsp7s3qupkd28uu5ZZOY7VclSOYwXdexR+rkflz4fjYB/6dYU4jtXklAkAAAZNIAYAYNBWKRCfWHYBc+I4VstROY5Vc1R+ro5jtRyV41g1R+Xn6jhWy1E5jiQr8o/qAABgWVZphRgAAA7d0gNxVV1VVfdW1X1V9ZZl1zONqnqwqv6kqu6qqpOjsWdV1e1V9cXR92cuu86zVdUNVfVIVd0zNrZj3XXGr4/++9xdVc9bXuXfbpfjeHtVnR79N7mrql419tovjY7j3qr68eVUvf707OHTs3r2oPTr4dOv69mvSw3EVXVOkv8nySuTfH+Sa6vq+5dZ0wG8tLuPjV165C1JPtLdVyT5yOj5qrkxyVVnje1W9yuTXDH6Op7k+kOqcRI35juPI0neNfpvcqy7b0uS0Z+ra5I8d7TPb47+/DEFPbs0N0bP6tkp6deluTH6de36ddkrxFcmua+7H+juv07y/iRXL7mmWV2d5KbR45uS/OQSa9lRd38syaNnDe9W99VJ3tdnfDLJM6rqgsOpdG+7HMdurk7y/u7+Znf/WZL7cubPH9PRs0ugZ/XsAenXJdCv69mvyw7EFyb58tjzU6OxddFJPlxVd1bV8dHY+d390OjxV5Kcv5zSprZb3ev43+jNo1893TD267R1PI5VtO4/Rz27mvTsYqz7z1C/rqYj2a/LDsTr7ke7+3k58yuPN1XVi8df7DOX8Fi7y3isa90j1ye5PMmxJA8leedyy2HF6NnVo2fZjX5dPUe2X5cdiE8nuXjs+UWjsbXQ3adH3x9J8sGc+fXAw9u/7hh9f2R5FU5lt7rX6r9Rdz/c3U9095NJ3pNv/cpmrY5jha31z1HPrh49u1Br/TPUr6vnKPfrsgPxp5NcUVWXVdVTcuaE7FuXXNNEquppVfX07cdJfizJPTlT/3Wjza5L8qHlVDi13eq+NcnPjP4l7I8k+frYr31WzlnnXr0mZ/6bJGeO45qqempVXZYz/4DhU4dd3xGgZ1eHnmU/+nV16NdV191L/UryqiRfSHJ/krcuu54p6n52kv85+vrsdu1Jvidn/gXpF5P8tyTPWnatO9R+c878quNvcuY8n9fvVneSypl/pXx/kj9Jsrns+vc5jv93VOfdOdOgF4xt/9bRcdyb5JXLrn9dv/TsUmrXs3r2oD9z/Xr4tevXNexXd6oDAGDQln3KBAAALJVADADAoAnEAAAMmkAMAMCgCcQAAAyaQAwAwKAJxAAADJpADADAoP3/1Wo52AeZZucAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 864x864 with 9 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# prediction\n",
    "\n",
    "import math\n",
    "\n",
    "model.eval()   # Set model to evaluate mode\n",
    "\n",
    "test_dataset = SimDataset(3, transform = trans)\n",
    "test_loader = DataLoader(test_dataset, batch_size=3, shuffle=False, num_workers=0)\n",
    "        \n",
    "inputs, labels = next(iter(test_loader))\n",
    "inputs = inputs.to(device)\n",
    "labels = labels.to(device)\n",
    "\n",
    "pred = model(inputs)\n",
    "\n",
    "pred = pred.data.cpu().numpy()\n",
    "print(pred.shape)\n",
    "\n",
    "# Change channel-order and make 3 channels for matplot\n",
    "input_images_rgb = [reverse_transform(x) for x in inputs.cpu()]\n",
    "\n",
    "# Map each channel (i.e. class) to each color\n",
    "target_masks_rgb = [helper.masks_to_colorimg(x) for x in labels.cpu().numpy()]\n",
    "pred_rgb = [helper.masks_to_colorimg(x) for x in pred]\n",
    "\n",
    "helper.plot_side_by_side([input_images_rgb, target_masks_rgb, pred_rgb])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:py36]",
   "language": "python",
   "name": "conda-env-py36-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
